In [1]:
import os

import pandas as pd
import tqdm
from skimage import io
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import wandb

from imgcl.config import TRAIN_DIR, TRAIN_LABELS_PATH, IDX_SIZE, \
                        ID_COLUMN, LABEL_COLUMN, TRAIN_SIZE, CHECKPOINT_DIR

In [2]:
# Restrict this kernel to GPU #2.
# A `!export ...` line runs in a separate shell process, so the variable never
# reaches the Python kernel; setting it through os.environ changes the kernel's
# own environment. It must be set before the first CUDA call in the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [3]:
# Hyperparameters and runtime settings shared by the cells below.
config = dict(
    lr=1e-4,                # optimizer learning rate
    epochs_num=20,          # number of passes over the training split
    log_each=50,            # run validation / logging every N training iterations
    device="cuda" if torch.cuda.is_available() else "cpu",
    train_batch_size=64,
    val_batch_size=1024,
)

In [4]:
# !pip install ../

In [5]:
# Load the training labels CSV into a DataFrame for ad-hoc inspection.
# NOTE(review): `df` is not referenced by any later cell visible in this notebook;
# ImageDataset reads the same file again on its own.
df = pd.read_csv(TRAIN_LABELS_PATH)

In [6]:
class ImageDataset(Dataset):
    """Map-style dataset that reads images from a directory and labels from a CSV.

    Integer index ``i`` is mapped to the file name ``trainval_<i zero-padded to
    IDX_SIZE>.jpg`` inside ``data_dir``; the label is looked up by that file name.

    Args:
        data_dir: directory containing the image files.
        labels_path: CSV file with an ``ID_COLUMN`` column (file names) and a
            ``LABEL_COLUMN`` column (labels).
    """

    def __init__(self, data_dir, labels_path):
        self.data_dir = data_dir
        labels = pd.read_csv(labels_path).set_index([ID_COLUMN])
        # file name -> label
        self.labels = dict(zip(labels.index, labels[LABEL_COLUMN].values))
        # Filled in lazily by __len__ so the directory is listed only once.
        self._length = None

    def __len__(self):
        """Return the number of entries in ``data_dir`` minus one.

        NOTE(review): the ``- 1`` presumably accounts for a single non-image
        entry living in the same directory — confirm against the data layout.
        The value is cached: samplers and the bounds check in ``__getitem__``
        call ``len()`` often, and re-listing the directory each time is wasteful.
        """
        length = getattr(self, "_length", None)
        if length is None:
            length = len(os.listdir(self.data_dir)) - 1
            self._length = length
        return length

    def __getitem__(self, idx):
        """Return ``{"image": float tensor (channels first), "label": label}``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``. Raising
                IndexError (instead of failing later on a missing file) lets
                plain ``for sample in dataset`` iteration terminate correctly.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                f"index {idx} is out of range for dataset of size {len(self)}"
            )
        file_name = self._transform_idx(idx)
        img = io.imread(os.path.join(self.data_dir, file_name))
        # The permute moves the last axis of the loaded array to the front.
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)
        return {"image": img_tensor.float(), "label": self.labels[file_name]}

    @staticmethod
    def _transform_idx(idx):
        """Convert an index into its image file name, e.g. 7 -> 'trainval_…007.jpg'."""
        return f"trainval_{str(idx).zfill(IDX_SIZE)}.jpg"

In [7]:
# Fixed seed for the train/validation partition. Without it random_split draws
# a different partition on every kernel restart, so validation numbers and
# saved checkpoints are not comparable between runs.
SPLIT_SEED = 42

dataset = ImageDataset(TRAIN_DIR, TRAIN_LABELS_PATH)

# TRAIN_SIZE is used as the fraction of samples that goes to the training split.
train_len = int(len(dataset) * TRAIN_SIZE)
val_len = len(dataset) - train_len
train_dataset, val_dataset = random_split(
    dataset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["train_batch_size"],
    shuffle=True,
)
# shuffle=True is kept on purpose: the training cell takes a single batch with
# next(iter(val_dataloader)), so shuffling makes that batch a fresh random
# sample of the validation split each time.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=config["val_batch_size"],
    shuffle=True,
)

### Model

In [8]:
import torch.nn as nn
import torch.nn.functional as F


In [9]:
class Model(nn.Module):
    """Small CNN classifier with two convolutions and two linear layers.

    Pipeline: conv(3->6, k=5) -> ReLU -> 2x2 max-pool -> conv(6->16, k=5) ->
    ReLU -> 2x2 max-pool -> flatten -> fc(16*7*7 -> 496) -> ReLU ->
    fc(496 -> num_classes).

    The first linear layer expects 16*7*7 features per sample, which is what a
    3x40x40 input produces after the two conv/pool stages.

    Args:
        num_classes: size of the output layer (defaults to 200, the value that
            was previously hard-coded).
    """

    # Per-sample feature count after the second pooling stage.
    _FLAT_FEATURES = 16 * 7 * 7

    def __init__(self, num_classes=200):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 496)
        self.fc2 = nn.Linear(496, num_classes)

    def forward(self, x, debug=False):
        """Compute class logits for a batch of images.

        Args:
            x: input batch; the first conv layer requires 3 channels.
            debug: when True, print the tensor shape after every stage.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalized scores.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not yield
                16*7*7 features per sample.
        """
        def trace(t):
            # Optional shape dump after each stage; returns its argument unchanged.
            if debug:
                print(t.shape)
            return t

        x = trace(F.relu(self.conv1(x)))
        x = trace(self.pool(x))
        x = trace(self.pool(F.relu(self.conv2(x))))
        # Keep the batch dimension explicit. With view(-1, features) a wrongly
        # sized input whose element count happens to be divisible by `features`
        # would be silently reshaped into a larger batch instead of failing.
        x = trace(x.view(x.size(0), self._FLAT_FEATURES))
        x = trace(F.relu(self.fc1(x)))
        return trace(self.fc2(x))


In [10]:
# test_img = train_dataset[0]['image'].unsqueeze(0)
# model = Model()
# model(test_img)

In [12]:
# Build the network on the configured device, then create the loss and the
# Adam optimizer over its parameters.
model = Model().to(config['device'])
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

Timing note: roughly 3 minutes per 1500 training iterations.

### Train

In [None]:
# Start a Weights & Biases run in the "dl_hse" project and record `config`
# as the run's configuration.
wandb.init(config=config, project="dl_hse")
# Register the model with wandb so it can be tracked during training
# (what exactly is logged depends on wandb.watch defaults — see wandb docs).
wandb.watch(model)

In [None]:
# Destination for the best-scoring weights. The file name is chosen here;
# CHECKPOINT_DIR is assumed to be a directory path — confirm in imgcl.config.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
best_checkpoint_path = os.path.join(CHECKPOINT_DIR, "best_model.pt")

model.train()
best_val_accuracy = 0
for epoch in range(config['epochs_num']):
    # Wrap the dataloader itself (not enumerate) so tqdm knows the total length.
    for i, data in enumerate(tqdm.tqdm(train_dataloader, desc=f"epoch {epoch}")):
        # data preparation
        inputs = data["image"].to(config['device'])
        labels = data["label"].to(config['device'])

        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, labels)

        # compute gradients
        loss.backward()

        # make a step
        optimizer.step()

        if i % config['log_each'] == 0:
            train_loss = loss.item()

            # Validation on one batch drawn from val_dataloader. eval mode and
            # no_grad keep it from building an autograd graph (the batch is
            # large) and from behaving like a training pass.
            model.eval()
            with torch.no_grad():
                val_data = next(iter(val_dataloader))
                val_inputs = val_data["image"].to(config['device'])
                val_labels = val_data["label"].to(config['device'])
                val_output = model(val_inputs)

                # Log plain Python floats rather than live tensors.
                val_loss = criterion(val_output, val_labels).item()
                val_preds = torch.argmax(val_output, dim=1)
                val_accuracy = (val_preds == val_labels).float().mean().item()
            model.train()

            wandb.log({
                "Train Loss": train_loss,
                "Val Loss": val_loss,
                "Val accuracy": val_accuracy
            })

            # Keep only the weights with the best validation accuracy so far.
            # The score comes from a single validation batch, so it is a noisy
            # estimate of the true validation accuracy.
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                torch.save(model.state_dict(), best_checkpoint_path)


1485it [01:08, 21.75it/s]
1485it [01:08, 21.64it/s]
1485it [01:08, 21.76it/s]
1485it [01:08, 21.65it/s]
1485it [01:08, 21.70it/s]
1485it [01:08, 21.65it/s]
718it [00:33, 22.72it/s]