In [1]:
!git clone https://github.com/sakanaowo/PlantXViT

Cloning into 'PlantXViT'...
remote: Enumerating objects: 50377, done.[K
remote: Counting objects: 100% (106/106), done.[K
remote: Compressing objects: 100% (73/73), done.[K
remote: Total 50377 (delta 42), reused 83 (delta 21), pack-reused 50271 (from 1)[K
Receiving objects: 100% (50377/50377), 1.66 GiB | 49.29 MiB/s, done.
Resolving deltas: 100% (30427/30427), done.
Updating files: 100% (50038/50038), done.


In [2]:
!git pull

fatal: not a git repository (or any parent up to mount point /kaggle)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


In [3]:
%cd PlantXViT

/kaggle/working/PlantXViT


In [4]:
!ls

configs  data  notebooks  outputs  requirements.txt  src  utils


## Preprocessing

Define the image transforms, then build the train/val/test datasets and data loaders.

In [5]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Training-time preprocessing pipeline.
# 1) Light augmentation: random horizontal flip plus brightness/contrast/saturation jitter (no hue).
_augmentation_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.0),
]
# 2) Resize to 224x224, convert to a float tensor, and normalize with the standard
#    ImageNet channel statistics (the VGG16 stem is loaded with pretrained weights).
_tensor_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
image_transforms = transforms.Compose(_augmentation_steps + _tensor_steps)


In [6]:
!ls ./data/raw/embrapa

test  train  val


In [7]:
import os

In [8]:
root_dir="./data/raw/embrapa"

In [9]:
# Augmentation (random flip / colour jitter) is for training only. Passing `image_transforms`
# to val/test made their loss/accuracy random from run to run, which also made checkpoint
# selection and early stopping noisy. Evaluation uses the same resize + normalization
# without the random steps.
eval_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.ImageFolder(os.path.join(root_dir, "train"), transform=image_transforms)
val_dataset = datasets.ImageFolder(os.path.join(root_dir, "val"), transform=eval_transforms)
test_dataset = datasets.ImageFolder(os.path.join(root_dir, "test"), transform=eval_transforms)

# ImageFolder assigns label indices from each split's own sorted folder names, so a split
# with a missing or extra class folder would silently shift every label after it.
assert val_dataset.classes == train_dataset.classes, "val class folders differ from train"
assert test_dataset.classes == train_dataset.classes, "test class folders differ from train"

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

In [10]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import VGG16_Weights


# Earlier inception block (revised: total output channels = 192 + 160 + 160 = 512); superseded by the refined version below.
# class InceptionBlock(nn.Module):
#     def __init__(self, in_channels):
#         super().__init__()
#         self.branch1x1 = nn.Conv2d(in_channels, 192, kernel_size=1)  # increased to 192

#         self.branch3x3 = nn.Sequential(
#             nn.Conv2d(in_channels, 160, kernel_size=(1, 3), padding=(0, 1)),
#             nn.Conv2d(160, 160, kernel_size=(3, 1), padding=(1, 0)),
#         )

#         self.branch_pool = nn.Sequential(
#             nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
#             nn.Conv2d(in_channels, 160, kernel_size=1),  # increased to 160
#         )

#     def forward(self, x):
#         b1 = self.branch1x1(x)
#         b2 = self.branch3x3(x)
#         b3 = self.branch_pool(x)
#         return torch.cat([b1, b2, b3], dim=1)  # Output shape: (B, 512, H, W)

# Inception block, refined version: 4 branches, 128 + 128 + 192 + 64 = 512 output channels
class InceptionBlock(nn.Module):
    """Inception-style block: four parallel branches concatenated along the channel axis.

    Every convolution is followed by ReLU and then BatchNorm2d (note: ReLU *before* BN,
    the reverse of the more common Conv -> BN -> ReLU ordering).

    Branches:
        branch1x1   : 1x1 conv                                          -> 128 channels
        branch3x3   : 1x1 (96) -> 3x1 (128) -> 1x3 (128)                -> 128 channels
        branch5x5   : 1x1 (64) -> 3x1 (96) -> 1x3 (96)
                      -> 3x1 (192) -> 1x3 (192)                         -> 192 channels
        branch_pool : 3x3 max-pool (stride 1, pad 1) -> 1x1 conv        ->  64 channels

    All paddings keep H and W unchanged, so an input of shape (B, in_channels, H, W)
    produces (B, 128 + 128 + 192 + 64, H, W) = (B, 512, H, W).

    Args:
        in_channels: number of channels of the incoming feature map (default 128).
    """

    def __init__(self, in_channels=128):
        super().__init__()
        unit = self._conv_relu_bn

        # Branch 1: plain 1x1 projection.
        self.branch1x1 = nn.Sequential(*unit(in_channels, 128, 1))

        # Branch 2: 1x1 reduction, then a 3x3 receptive field factorized as 3x1 + 1x3.
        self.branch3x3 = nn.Sequential(
            *unit(in_channels, 96, 1),
            *unit(96, 128, (3, 1), (1, 0)),
            *unit(128, 128, (1, 3), (0, 1)),
        )

        # Branch 3: 1x1 reduction, then two stacked (3x1 + 1x3) pairs -> 5x5 receptive field.
        self.branch5x5 = nn.Sequential(
            *unit(in_channels, 64, 1),
            *unit(64, 96, (3, 1), (1, 0)),
            *unit(96, 96, (1, 3), (0, 1)),
            *unit(96, 192, (3, 1), (1, 0)),
            *unit(192, 192, (1, 3), (0, 1)),
        )

        # Branch 4: size-preserving max-pool followed by a 1x1 projection.
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *unit(in_channels, 64, 1),
        )

    @staticmethod
    def _conv_relu_bn(in_ch, out_ch, kernel_size, padding=0):
        """Return the [Conv2d, ReLU, BatchNorm2d] triple, to be unpacked into an nn.Sequential."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
        ]

    def forward(self, x):
        branches = (self.branch1x1, self.branch3x3, self.branch5x5, self.branch_pool)
        return torch.cat([branch(x) for branch in branches], dim=1)


# patch embedding: split patch -> Linear
class PatchEmbedding(nn.Module):
    """Cut a (B, C, H, W) feature map into non-overlapping square patches and embed each linearly.

    Args:
        in_channels: channel count C of the incoming feature map.
        patch_size: side length P of each square patch. The stride equals P, so patches do
            not overlap. Rows/columns past the last full patch are silently dropped by
            ``Tensor.unfold`` (e.g. a 56x56 map with P=5 only uses its top-left 55x55).
        emb_size: output embedding dimension.

    Output shape: (B, (H // P) * (W // P), emb_size), with patches in row-major order and
    each patch flattened as (C, P, P) before projection.
    """

    def __init__(self, in_channels, patch_size=5, emb_size=16):
        super().__init__()
        self.patch_size = patch_size
        self.emb_size = emb_size
        flat_patch_len = in_channels * patch_size * patch_size
        self.proj = nn.Linear(flat_patch_len, emb_size)

    def forward(self, x):
        batch, channels, _, _ = x.shape
        p = self.patch_size

        # (B, C, H, W) -> (B, C, nH, nW, P, P): sliding P-windows with stride P on H, then W.
        windows = x.unfold(2, p, p).unfold(3, p, p)
        # -> (B, nH, nW, C, P, P) so each patch's C*P*P values sit together.
        windows = windows.permute(0, 2, 3, 1, 4, 5)
        # -> (B, nH*nW, C*P*P): one flat vector per patch.
        tokens = windows.reshape(batch, -1, channels * p * p)
        return self.proj(tokens)


# -------- Transformer Encoder Block (ViT block) --------
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block (ViT style) over a token sequence.

    x -> x + MHA(LN(x))  ->  x + MLP(LN(x)), where the MLP is
    Linear(emb, 2*emb) -> GELU -> Linear(2*emb, emb) -> Dropout.

    Args:
        emb_size: token embedding dimension; must be divisible by ``num_heads``.
        dropout: dropout probability applied at the end of the MLP sub-layer. No dropout is
            configured inside the attention itself (nn.MultiheadAttention's default).
        num_heads: number of attention heads (default 2, the previously hard-coded value).

    Input/output shape: (B, N, emb_size) — attention is built with ``batch_first=True``.
    """

    def __init__(self, emb_size=16, dropout=0.1, num_heads=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        self.attn = nn.MultiheadAttention(emb_size, num_heads=num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_size)
        self.mlp = nn.Sequential(
            nn.Linear(emb_size, emb_size * 2),
            nn.GELU(),
            nn.Linear(emb_size * 2, emb_size),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Normalize once and reuse it as query, key and value; the previous version ran the
        # identical LayerNorm three times per forward pass.
        h = self.norm1(x)
        # The attention map is discarded, so don't ask MHA to compute/average it;
        # need_weights=False also allows PyTorch's optimized attention path.
        x_attn, _ = self.attn(h, h, h, need_weights=False)
        x = x + x_attn
        x = x + self.mlp(self.norm2(x))
        return x


# -------- PlantXViT Model --------
class PlantXViT(nn.Module):
    """PlantXViT: VGG16 stem -> Inception-style block -> patch embedding -> ViT blocks -> linear head.

    Shapes below assume a 3x224x224 input (the size produced by the Resize((224, 224)) transform):
        VGG16 features[:10]  -> (B, 128, 56, 56)
        InceptionBlock       -> (B, 512, 56, 56)    # 128 + 128 + 192 + 64 channels
        PatchEmbedding (P=5) -> (B, 121, emb_size)  # 11 x 11 patches; last row/col of 56 dropped
        Transformer blocks   -> (B, 121, emb_size)
        LayerNorm + mean over tokens -> (B, emb_size) -> Linear -> (B, num_classes)

    Args:
        num_classes: number of output logits.
        patch_size: side length of the square patches fed to the transformer.
        emb_size: token embedding dimension.
        num_blocks: number of stacked TransformerBlocks.
        dropout: dropout used inside each TransformerBlock's MLP.
        pretrained: load the torchvision ``VGG16_Weights.DEFAULT`` weights for the stem
            (default True, the previous behaviour, which downloads them on first use);
            pass False to build the model offline with a randomly initialized stem.
    """

    def __init__(self, num_classes=4, patch_size=5, emb_size=16, num_blocks=4, dropout=0.1,
                 pretrained=True):
        super().__init__()

        # VGG16 stem: the first two conv blocks (conv-relu-conv-relu-maxpool, twice),
        # i.e. features[:10]; (B, 3, 224, 224) -> (B, 128, 56, 56).
        vgg = models.vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None)
        self.vgg_block = nn.Sequential(*list(vgg.features[:10]))

        # Inception-style block -> (B, 512, 56, 56).
        self.inception = InceptionBlock(in_channels=128)

        # Patch embedding -> (B, 121, emb_size) for patch_size=5.
        # in_channels must equal the InceptionBlock output width: 128 + 128 + 192 + 64 = 512
        # (the "384" in older comments was stale).
        self.patch_embed = PatchEmbedding(in_channels=512, patch_size=patch_size, emb_size=emb_size)

        # Transformer encoder blocks (token count and width unchanged).
        self.transformer = nn.Sequential(*[TransformerBlock(emb_size, dropout) for _ in range(num_blocks)])

        # Classification head: final LayerNorm, average over tokens, linear classifier.
        self.norm = nn.LayerNorm(emb_size)
        self.global_pool = nn.AdaptiveAvgPool1d(1)  # applied on (B, emb_size, N) -> (B, emb_size, 1)
        self.classifier = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        x = self.vgg_block(x)  # (B, 128, 56, 56)
        x = self.inception(x)  # (B, 512, 56, 56)
        x = self.patch_embed(x)  # (B, 121, emb_size)
        x = self.transformer(x)  # (B, 121, emb_size)
        x = self.norm(x)  # (B, 121, emb_size)
        x = x.permute(0, 2, 1)  # (B, emb_size, 121) so pooling runs over the token axis
        x = self.global_pool(x).squeeze(-1)  # (B, emb_size): mean over tokens
        return self.classifier(x)  # (B, num_classes)


In [11]:
# Take the class count from the training folders so the classifier head always matches the
# label indices ImageFolder assigns (the previously hard-coded value here was 93).
model = PlantXViT(num_classes=len(train_dataset.classes))
criterion = nn.CrossEntropyLoss()

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 209MB/s] 


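## Training

Load the config, set hyperparameters, then train with checkpointing on the best validation accuracy and early stopping.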

In [12]:
!cat utils/config_loader.py

import yaml


def load_config(path="../configs/config.yaml"):
    with open(path, "r") as f:
        return yaml.safe_load(f)


In [13]:
from utils.config_loader import load_config
config=load_config('./configs/config.yaml')

In [14]:
print(config['output']['embrapa']['model_path'])

./outputs/embrapa/models/plantxvit_best.pth


In [15]:
import os

In [16]:
!ls outputs/embrapa/models

plantxvit_best_embrapa.pth


In [17]:
DATA_DIR = root_dir
BATCH_SIZE = 16  # NOTE(review): informational only — the DataLoaders above hard-code batch_size=16
EPOCHS = 50
LR = 1e-4
NUM_CLASSES = 93  # NOTE(review): should match the num_classes the model above was built with
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU,
# instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Single source of truth for the checkpoint path: configs/config.yaml
# (it currently resolves to ./outputs/embrapa/models/plantxvit_best.pth).
MODEL_PATH = config['output']['embrapa']['model_path']

# Create the checkpoint directory if it does not exist yet.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from tqdm import tqdm

In [19]:

model.to(DEVICE)
# Use the LR constant from the config cell so changing it actually takes effect
# (the same value used to be duplicated here as a literal 1e-4).
# NOTE(review): Adam's weight_decay is an L2 penalty added to the gradient; AdamW would give
# decoupled weight decay if that is what was intended.
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)

In [20]:
print(DEVICE)

cuda


In [21]:
def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Run one optimization pass over ``loader`` and return (mean loss, accuracy).

    Args:
        model: classifier; its output is assumed to be logits of shape (batch, num_classes).
        loader: iterable of (inputs, labels) batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss taking (logits, labels). It is assumed to return a batch-mean loss,
            which is re-weighted by batch size so the result is a per-sample mean.
        device: where each batch is moved; defaults to the notebook-level ``DEVICE``.

    Returns:
        (avg_loss, acc), both averaged over every sample seen in this epoch.

    Raises:
        ValueError: if ``loader`` yields no samples (previously a bare ZeroDivisionError).
    """
    device = DEVICE if device is None else device
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in tqdm(loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return running_loss / total, correct / total




In [22]:
def evaluate(model, loader, criterion, device=None):
    """Compute (mean loss, accuracy) of ``model`` over ``loader`` without updating it.

    The model is switched to eval mode (dropout off, BatchNorm uses running stats) and
    gradients are disabled.

    Args:
        model: classifier; its output is assumed to be logits of shape (batch, num_classes).
        loader: iterable of (inputs, labels) batches.
        criterion: loss taking (logits, labels). It is assumed to return a batch-mean loss,
            which is re-weighted by batch size so the result is a per-sample mean.
        device: where each batch is moved; defaults to the notebook-level ``DEVICE``.

    Returns:
        (avg_loss, acc), both averaged over every sample in ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples (previously a bare ZeroDivisionError).
    """
    device = DEVICE if device is None else device
    model.eval()
    running_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return running_loss / total, correct / total


In [None]:
# Start below any achievable accuracy so the first epoch is always checkpointed
# (with 0, a model stuck at 0.0 val accuracy would never be saved).
best_val_acc = float("-inf")
# Early stopping: stop after `patience` consecutive epochs without a new best val accuracy.
patience, wait = 5, 0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)

    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Acc: {val_acc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"✅ Saved best model to {MODEL_PATH}")
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

# The loop leaves `model` holding the *last* epoch's weights. Restore the checkpoint that was
# actually selected so later cells (e.g. test evaluation) use the best model.
if best_val_acc > float("-inf"):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    print(f"Restored best model (val acc {best_val_acc:.4f}) from {MODEL_PATH}")



Epoch 1/50


Training: 100%|██████████| 1851/1851 [03:22<00:00,  9.12it/s]
Evaluating: 100%|██████████| 466/466 [00:21<00:00, 22.16it/s]


Train Loss: 3.3308 | Acc: 0.3331
Val   Loss: 2.6582 | Acc: 0.4710
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 2/50


Training: 100%|██████████| 1851/1851 [03:23<00:00,  9.11it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.44it/s]


Train Loss: 2.3098 | Acc: 0.4958
Val   Loss: 1.9044 | Acc: 0.5836
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 3/50


Training: 100%|██████████| 1851/1851 [03:22<00:00,  9.14it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.57it/s]


Train Loss: 1.7258 | Acc: 0.6038
Val   Loss: 1.4550 | Acc: 0.6708
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 4/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.18it/s]
Evaluating: 100%|██████████| 466/466 [00:21<00:00, 22.19it/s]


Train Loss: 1.3461 | Acc: 0.6878
Val   Loss: 1.1953 | Acc: 0.7172
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 5/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.19it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.29it/s]


Train Loss: 1.1103 | Acc: 0.7336
Val   Loss: 1.0832 | Acc: 0.7262
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 6/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.18it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.42it/s]


Train Loss: 0.9340 | Acc: 0.7680
Val   Loss: 0.8805 | Acc: 0.7832
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 7/50


Training: 100%|██████████| 1851/1851 [03:22<00:00,  9.13it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.28it/s]


Train Loss: 0.8183 | Acc: 0.7936
Val   Loss: 0.8044 | Acc: 0.7922
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 8/50


Training: 100%|██████████| 1851/1851 [03:23<00:00,  9.11it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.32it/s]


Train Loss: 0.7231 | Acc: 0.8144
Val   Loss: 0.7011 | Acc: 0.8216
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 9/50


Training: 100%|██████████| 1851/1851 [03:22<00:00,  9.16it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.41it/s]


Train Loss: 0.6485 | Acc: 0.8307
Val   Loss: 0.6582 | Acc: 0.8252
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 10/50


Training: 100%|██████████| 1851/1851 [03:23<00:00,  9.12it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.56it/s]


Train Loss: 0.5884 | Acc: 0.8430
Val   Loss: 0.6312 | Acc: 0.8333
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 11/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.18it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.55it/s]


Train Loss: 0.5279 | Acc: 0.8575
Val   Loss: 0.6260 | Acc: 0.8270

Epoch 12/50


Training: 100%|██████████| 1851/1851 [03:23<00:00,  9.10it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.45it/s]


Train Loss: 0.4913 | Acc: 0.8674
Val   Loss: 0.5763 | Acc: 0.8381
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 13/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.18it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.44it/s]


Train Loss: 0.4460 | Acc: 0.8786
Val   Loss: 0.5534 | Acc: 0.8451
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 14/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.19it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.61it/s]


Train Loss: 0.4119 | Acc: 0.8862
Val   Loss: 0.5185 | Acc: 0.8551
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 15/50


Training: 100%|██████████| 1851/1851 [03:22<00:00,  9.14it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.65it/s]


Train Loss: 0.3831 | Acc: 0.8926
Val   Loss: 0.5331 | Acc: 0.8453

Epoch 16/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.18it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.60it/s]


Train Loss: 0.3523 | Acc: 0.9019
Val   Loss: 0.4912 | Acc: 0.8575
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 17/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.18it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.49it/s]


Train Loss: 0.3318 | Acc: 0.9063
Val   Loss: 0.4635 | Acc: 0.8691
✅ Saved best model to ./outputs/embrapa/models/plantxvit_best.pth

Epoch 18/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.17it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.62it/s]


Train Loss: 0.3069 | Acc: 0.9131
Val   Loss: 0.4839 | Acc: 0.8596

Epoch 19/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.21it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.75it/s]


Train Loss: 0.2888 | Acc: 0.9182
Val   Loss: 0.4778 | Acc: 0.8648

Epoch 20/50


Training: 100%|██████████| 1851/1851 [03:22<00:00,  9.12it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.58it/s]


Train Loss: 0.2719 | Acc: 0.9225
Val   Loss: 0.4925 | Acc: 0.8585

Epoch 21/50


Training: 100%|██████████| 1851/1851 [03:21<00:00,  9.20it/s]
Evaluating: 100%|██████████| 466/466 [00:20<00:00, 22.71it/s]


Train Loss: 0.2598 | Acc: 0.9273
Val   Loss: 0.4744 | Acc: 0.8652

Epoch 22/50


Training:  25%|██▍       | 457/1851 [00:50<02:29,  9.32it/s]