In [8]:
import sys
sys.path.append('/Users/parth/Desktop/Flash/')

In [9]:
import torch 
import torch.nn as nn
import numpy as np
import pandas as pd 
import neptune

from sklearn import metrics
from sklearn.model_selection import train_test_split

import timm
import albumentations
from neptune.types import File
from torchret import Model 
import plotly.graph_objects as go

In [10]:
configs = {
    'lr' : 1e-4,
    'eta_min' : 1e-6,
    'T_0' : 20,
    'epochs' : 20,
    'step_scheduler_after' : 'epoch',

    'train_bs' : 256,
    'valid_bs' : 256,

    'num_workers' : 0,
    'pin_memory' : False,

    'model_name' : 'resnet10t',
    'pretrained' : True,
    'num_classes' : 10,
    'in_channels' : 1,
    'device' : 'mps',

    'model_path' : 'digit-recognizer.pt',
    'save_best_model' : 'on_eval_metric',
    'save_on_metric' : 'accuracy',
    'save_model_at_every_epoch' : False,

}

In [11]:
class DigitRecognizerDataset(torch.utils.data.Dataset):
    def __init__(self, df, augmentations):
        self.df = df
        self.targets = df.label.values
        self.df = self.df.drop(columns=["label"])
        self.augmentations = augmentations

        self.images = self.df.to_numpy(dtype=np.float32).reshape((-1, 28, 28))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        targets = self.targets[item]
        image = self.images[item]
        image = np.expand_dims(image, axis=0)

        return {
            "images": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(targets, dtype=torch.long),
        }

In [12]:
train_augs = albumentations.Compose(
        [
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0,
            ),
        ],
        p=1.0,
    )

valid_augs = albumentations.Compose(
        [
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
                p=1.0,
            ),
        ],
        p=1.0,
    )

In [13]:
class DigitRecognizerModel(Model):
    def __init__(self):
        super().__init__()

        self.model = timm.create_model(
            model_name = configs['model_name'],
            pretrained=configs['pretrained'],
            in_chans=configs['in_channels'],
            num_classes=configs['num_classes'],
        )

        self.num_workers = configs['num_workers']
        self.pin_memory = configs['pin_memory']
        self.step_scheduler_after = configs['step_scheduler_after']

        self.model_path = configs['model_path']
        self.save_best_model = configs['save_best_model']
        self.save_on_metric = configs['save_on_metric']
        self.save_model_at_every_epoch = configs['save_model_at_every_epoch']

    def setup_logger(self):
        neptune_api = NEPTUNE_API_TOKEN
        self.run = neptune.init_run(
            project='fenilsavani62/Digit-recog',
            api_token=neptune_api,
            capture_stdout=True,       # Enable capture of stdout
            capture_stderr=True,       # Enable capture of stderr
            capture_traceback=True,    # Enable capture of traceback
            capture_hardware_metrics=True,  # Enable capture of hardware metrics
            source_files='*.ipynb'  # Capture notebook outputs
        )
        self.run['parameters'] = configs

    def valid_one_step_logs(self, batch_id, data, logits, loss, metrics):
        if batch_id % len(self.validloader) == 100:
            images = data['images']
            labels = data['targets']
            outputs = np.argmax(logits.cpu().detach().numpy(), axis=1)
            images = images.permute(0, 2, 3, 1).squeeze().cpu()
            for i in range(len(images)):
                description = f'true label : {labels[i]} prediction : {outputs[i]}'
                self.run["valid/prediction_example"].append(File.as_image(images[i]), description = description)

    def train_one_epoch_logs(self, loss, monitor):
        self.run['train/loss'].append(loss)
        self.run['train/monitors'].append(monitor)

    def valid_one_epoch_logs(self, loss, monitor):
        self.run['valid/loss'].append(loss)
        self.run['valid/monitors'].append(monitor)

    def monitor_metrics(self, outputs, targets):
        device = targets.device.type
        outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1)
        targets = targets.cpu().detach().numpy()
        acc = metrics.accuracy_score(targets, outputs)
        acc = torch.tensor(acc).float()
        f1_score = metrics.f1_score(targets, outputs, average = 'macro')
        f1_score = torch.tensor(f1_score)
        return {"accuracy": acc, 'f1_score' : f1_score}
    
    def monitor_loss(self, outputs, targets):
        loss = nn.CrossEntropyLoss()(outputs, targets)
        return loss

    def fetch_optimizer(self):
        opt = torch.optim.SGD(
            self.parameters(),
            lr=configs['lr'],
            momentum=0.9,
        )
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt, 
            T_0 = configs['T_0'],
            eta_min = configs['eta_min'],
            T_mult= 1
        )
        return opt, sch

    def forward(self, images, targets=None):
        x = self.model(images)
        if targets is not None:
            loss = self.monitor_loss(x, targets)
            metrics = self.monitor_metrics(x, targets)
            return x, loss, metrics
        return x, 0, {}

In [14]:
def main():
    df = pd.read_csv('train.csv')
    train, test = train_test_split(df, test_size=0.2)

    train_dataset = DigitRecognizerDataset(df = train, augmentations = train_augs)
    valid_dataset = DigitRecognizerDataset(df = test, augmentations = valid_augs)

    model = DigitRecognizerModel()
    model.fit(train_dataset, valid_dataset, train_bs = configs['train_bs'], valid_bs = configs['valid_bs'], device = configs['device'], epochs = configs['epochs'], logger=False)

if __name__ == "__main__":
    main()

100%|██████████| 132/132 [00:20<00:00,  6.56it/s, accuracy=0.282, current_lr=0.0001, epoch=1, f1_score=0.274, loss=2.096917, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.47it/s, accuracy=0.472, epoch=1, f1_score=0.456, loss=1.647311, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.4721008159897544 accuracy


100%|██████████| 132/132 [00:21<00:00,  6.05it/s, accuracy=0.584, current_lr=0.0001, epoch=2, f1_score=0.571, loss=1.366865, stage=train]
100%|██████████| 33/33 [00:01<00:00, 19.51it/s, accuracy=0.674, epoch=2, f1_score=0.658, loss=1.140475, stage=eval]


Model Saved at digit-recognizer.pt
Model was saved based on_eval_metric with 0.674269741231745 accuracy


 97%|█████████▋| 128/132 [00:19<00:00,  6.40it/s, accuracy=0.718, current_lr=0.0001, epoch=3, f1_score=0.709, loss=0.985603, stage=train]


KeyboardInterrupt: 