Code sources: https://www.kaggle.com/code/uyiosaenabulele/garbage-classification-resnet50-96-acc
and https://www.kaggle.com/code/arinalhaq/garbage-classification-resnet50-96-acc
Data: https://www.kaggle.com/code/arinalhaq/garbage-classification-resnet50-96-acc/data

In [28]:
# notebook78590f1de4

import os
import torch
import torchvision
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from torch import nn, utils
from torchvision import transforms
import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt

In [29]:
# Preprocessing pipelines keyed by split.
#   'train': tensor conversion, random resized crop to 224x224 and a random
#            horizontal flip (data augmentation).
#   'test' : resize to 224x224 and tensor conversion only. No random ops here,
#            so evaluation on the held-out images is repeatable.
# NOTE(review): no Normalize step is applied (it is commented out below) even
# though the backbones load pretrained weights. Confirm this is intentional
# before adding one, because the exported ONNX/TFLite model inherits whatever
# input convention is used here.
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        # transforms.Grayscale(),
        # transforms.Normalize(([0.6731, 0.6398, 0.6048]), ([0.1944, 0.1931, 0.2049]))
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
}

In [30]:
# Two ImageFolder views over the same directory: identical files and labels,
# each with its own per-split transform pipeline.
DATASET_PATH = '../input/garbage-classification/garbage_classification'
train_dataset = torchvision.datasets.ImageFolder(
    root=DATASET_PATH,
    transform=data_transforms['train'],
)
test_dataset = torchvision.datasets.ImageFolder(
    root=DATASET_PATH,
    transform=data_transforms['test'],
)

KeyboardInterrupt: 

In [None]:
# Show the class-name -> index mapping, and keep the class names indexable by label.
print(train_dataset.class_to_idx)
LABELS = train_dataset.classes

In [None]:
# Seed NumPy and torch so the shuffled index order (and any later torch
# randomness) is reproducible between runs.
np.random.seed(1)
torch.manual_seed(1)
indices = np.random.permutation(len(train_dataset)).tolist()

In [None]:
# Hold out the last `test_ratio` share of the shuffled indices as the test set.
test_ratio = 0.2
test_border = len(train_dataset) - int(test_ratio * len(train_dataset))

# Both subsets index the same shuffled list. The train part reads from
# train_dataset ('train' transforms); the test part reads from test_dataset
# ('test' transforms).
train_data = utils.data.Subset(train_dataset, indices[:test_border])
test_data = utils.data.Subset(test_dataset, indices[test_border:])
len(train_data), len(test_data)

In [None]:
# Split the training subset 90/10 into train and validation.
train_size = int(0.9 * len(train_data))
val_size = len(train_data) - train_size

# Draw the same permutation random_split would draw from the (seeded) default
# generator. Map the positions back to underlying image indices, and rebuild
# the validation subset over `test_dataset`. The validation images are then
# preprocessed with the 'test' transforms instead of the random 'train'
# augmentation they would keep if they stayed a slice of `train_data`.
perm = torch.randperm(train_size + val_size, generator=torch.default_generator).tolist()
base_indices = train_data.indices
train_data = utils.data.Subset(train_dataset, [base_indices[i] for i in perm[:train_size]])
val_data = utils.data.Subset(test_dataset, [base_indices[i] for i in perm[train_size:]])
train_size, val_size

In [None]:
def show_image(image, label):
    """Display one image with its class index and class name as the title.

    Args:
        image: an image tensor as produced by the dataset transforms; it is
            converted with ``transforms.ToPILImage`` and shown as RGB.
        label: class index into ``LABELS``; a Python int or a 0-d tensor
            (e.g. one element of a batch of labels from a DataLoader).
    """
    # Normalize to a plain int so the title reads "label: 3" rather than
    # "label: tensor(3)" when a tensor element is passed in.
    label = int(label)
    plt.title(f"label: {label}, {LABELS[label]}\n")
    plt.imshow(transforms.ToPILImage()(image).convert('RGB'))

In [None]:
# Preview the first sample of the training split.
image, label = train_data[0]
show_image(image=image, label=label)

In [None]:
# Preview the first sample of the validation split.
image, label = val_data[0]
show_image(image=image, label=label)

In [None]:
# Preview the first sample of the test split.
image, label = test_data[0]
show_image(image=image, label=label)

In [None]:
# Batches of 32 with 2 worker processes. Only the training loader reshuffles
# each epoch; validation and test iterate in a fixed order.
train_loader = utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
val_loader = utils.data.DataLoader(val_data, batch_size=32, shuffle=False, num_workers=2)
test_loader = utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [None]:
# Pull one (shuffled) training batch and preview its first element.
inputs, labels = next(iter(train_loader))
show_image(image=inputs[0], label=labels[0])

In [None]:
from pytorch_lightning.callbacks import Callback

class MetricMonitor(Callback):
    """Print and record epoch-level metrics at the end of every training epoch.

    On each ``on_train_epoch_end`` the callback snapshots
    ``trainer.logged_metrics`` (converted to plain floats), prints a one-line
    summary of train/val loss and accuracy, and appends the snapshot to
    ``self.history``.
    """

    def __init__(self):
        super().__init__()
        # One {metric_name: float} dict per completed training epoch.
        self.history = []
        # Epoch index shown in the summary line; incremented after each epoch.
        self.epoch = 0

    def on_train_epoch_end(self, trainer, pl_module):
        elogs = {item: float(value) for item, value in trainer.logged_metrics.items()}
        # A metric that was not logged (e.g. if validation did not run) prints
        # as "nan" instead of raising KeyError and aborting training.
        missing = float('nan')
        print(
            f"Epoch [{self.epoch}] "
            f"train_loss: {elogs.get('train_loss_epoch', missing):.3f}, "
            f"val_loss: {elogs.get('val_loss', missing):.3f}, "
            f"train_acc: {elogs.get('train_acc', missing):.3f}, "
            f"val_acc: {elogs.get('val_acc', missing):.3f}"
        )
        self.epoch += 1
        self.history.append(elogs)

In [None]:
import torchvision.models as models
import torchmetrics
import torch.nn.functional as F

class LigResNet(pl.LightningModule):
    """LightningModule wrapping a pretrained torchvision ResNet-50 classifier.

    The backbone's ``fc`` head is replaced by a fresh ``nn.Linear`` with
    ``num_class`` outputs. Any extra positional/keyword arguments are only
    captured by ``save_hyperparameters``.
    """

    def __init__(self, lr, num_class, *args, **kwargs):
        super().__init__()

        # Makes lr / num_class available as self.hparams.*.
        self.save_hyperparameters()

        backbone = models.resnet50(pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, num_class)
        self.model = backbone

        # Separate metric objects so each stage accumulates its own state.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """Adam at ``hparams.lr`` with a StepLR schedule (step_size=5)."""
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        sched = torch.optim.lr_scheduler.StepLR(opt, step_size=5)
        return [opt], [sched]

    def _loss_and_update(self, batch, metric):
        """Compute cross-entropy for ``batch``, feed argmax predictions to ``metric``."""
        inputs, targets = batch
        logits = self.model(inputs)
        loss = F.cross_entropy(logits, targets)
        metric(logits.argmax(dim=1), targets)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._loss_and_update(batch, self.train_acc)
        self.log('train_loss', loss.item(), on_epoch=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._loss_and_update(batch, self.val_acc)
        self.log('val_loss', loss.item(), on_epoch=True)
        self.log('val_acc', self.val_acc, on_epoch=True)

    def test_step(self, batch, batch_idx):
        loss = self._loss_and_update(batch, self.test_acc)
        self.log('test_loss', loss.item(), on_epoch=True)
        self.log('test_acc', self.test_acc, on_epoch=True)

    def predict_step(self, batch, batch_idx):
        # Batches are (inputs, labels) pairs; labels are ignored here.
        inputs, _ = batch
        return self.model(inputs)

In [None]:
class LigResNeXt(pl.LightningModule):
    """LightningModule wrapping a pretrained torchvision ResNeXt-50 (32x4d).

    The classifier head ``fc`` is swapped for an ``nn.Linear`` producing
    ``num_class`` logits. Training uses Adam plus a StepLR schedule; accuracy
    is tracked separately for the train, val and test stages.
    """

    def __init__(self, lr, num_class, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()  # exposes lr / num_class via self.hparams

        self.model = models.resnext50_32x4d(pretrained=True)
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_class)

        self.train_acc, self.val_acc, self.test_acc = (
            torchmetrics.Accuracy(),
            torchmetrics.Accuracy(),
            torchmetrics.Accuracy(),
        )

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        adam = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return [adam], [torch.optim.lr_scheduler.StepLR(adam, step_size=5)]

    def _logits_and_loss(self, batch):
        """Return ``(logits, targets, loss)`` for one labelled batch."""
        images, targets = batch
        logits = self.model(images)
        return logits, targets, F.cross_entropy(logits, targets)

    def training_step(self, batch, batch_idx):
        logits, targets, loss = self._logits_and_loss(batch)
        self.train_acc(torch.argmax(logits, dim=1), targets)
        self.log('train_loss', loss.item(), on_epoch=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, targets, loss = self._logits_and_loss(batch)
        self.val_acc(torch.argmax(logits, dim=1), targets)
        self.log('val_loss', loss.item(), on_epoch=True)
        self.log('val_acc', self.val_acc, on_epoch=True)

    def test_step(self, batch, batch_idx):
        logits, targets, loss = self._logits_and_loss(batch)
        self.test_acc(torch.argmax(logits, dim=1), targets)
        self.log('test_loss', loss.item(), on_epoch=True)
        self.log('test_acc', self.test_acc, on_epoch=True)

    def predict_step(self, batch, batch_idx):
        images, _labels = batch
        return self.model(images)

In [None]:
# One output logit per dataset class.
num_class = len(LABELS)

model_1 = LigResNet(lr=5e-5, num_class=num_class)
model_1.model.fc  # display the replaced classifier head

In [None]:
model_2 = LigResNeXt(lr=5e-5, num_class=num_class)
model_2.model.fc  # display the replaced classifier head

In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

# Callbacks for the ResNet-50 run: console metric printer, early stopping on
# val_loss (patience 3), and the 3 best checkpoints ranked by val_loss.
mm = MetricMonitor()
# NOTE(review): csv_log is never passed to the Trainer below; confirm whether
# `logger=csv_log` was intended (doing so would move the run/checkpoint dirs).
csv_log = CSVLogger('logs', name='metric')
es = EarlyStopping(monitor='val_loss', patience=3)
mc = ModelCheckpoint(monitor='val_loss', filename='{epoch}-{val_loss}', save_top_k=3)

# Single-GPU trainer; each epoch is capped at 100 training batches.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=20,
    limit_train_batches=100,
    callbacks=[mm, es, mc],
    default_root_dir='./logs/resnet',
)

In [None]:
# Fine-tune the ResNet-50, then alias it as `model` for the export cells.
trainer.fit(model_1, train_dataloaders=train_loader, val_dataloaders=val_loader)
model = model_1

Save the trained weights and export the model to ONNX for later conversion.

In [None]:
# Save the trained weights, then export the network to ONNX.
# NOTE(review): the file is named "mnist.pth" although it holds this
# classifier's weights; the name is kept so existing consumers still find it.
torch.save(model.state_dict(), 'mnist.pth')

# NOTE(review): training/eval preprocessing resizes to 224x224, but the export
# input is 640x640. Confirm the deployment side really feeds 640x640 images.
img_size = (640, 640)
batch_size = 1
onnx_model_path = 'model.onnx'

# Inference mode: BatchNorm must use its running statistics rather than
# having them updated by the random dummy input below.
model.eval()

# Create the dummy input on the same device as the model's parameters.
sample_input = torch.rand(
    (batch_size, 3, *img_size), device=next(model.parameters()).device
)

# Sanity-check forward pass; no autograd graph is needed.
with torch.no_grad():
    y = model(sample_input)

torch.onnx.export(
    model,
    sample_input,
    onnx_model_path,
    verbose=False,
    input_names=['input'],
    output_names=['output'],
    opset_version=12
)

Convert the ONNX model to a TensorFlow SavedModel (via onnx-tf).

In [None]:
pip install onnx_tf

In [None]:
import onnx
from onnx_tf.backend import prepare

# Input: the ONNX file written by the export cell. Output: a TensorFlow model
# directory (loaded with from_saved_model in the next cell).
onnx_model_path = 'model.onnx'
tf_model_path = 'model_tf'

# Load the ONNX graph, wrap it as a TensorFlow representation, write it out.
onnx_model = onnx.load(onnx_model_path)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(tf_model_path)

Convert the TensorFlow SavedModel to a TFLite model.

In [None]:
import tensorflow as tf

# Input: the directory produced by onnx_tf. Output: a single .tflite file.
saved_model_dir = 'model_tf'
tflite_model_path = 'model.tflite'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()  # serialized model, written in binary mode below

with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# List the checkpoints the ResNet ModelCheckpoint wrote during fit. Reading
# the directory from the callback avoids hard-coding a run version (version_2)
# that changes every time the notebook is re-run.
os.listdir(mc.dirpath)

In [None]:
# Evaluate the best (lowest val_loss) ResNet checkpoint on the test split.
trainer.test(model_1, dataloaders=test_loader, ckpt_path='best')

In [None]:
# Evaluate one specific ResNet checkpoint from an earlier run on the test split.
trainer.test(
    model_1,
    dataloaders=test_loader,
    ckpt_path='./logs/resnet/lightning_logs/version_2/checkpoints/epoch=12-val_loss=0.16482672095298767.ckpt',
)

In [None]:
# Fresh callbacks for the ResNeXt run: metric printer, early stopping on
# val_loss (library-default patience), and the 3 best checkpoints by val_loss.
mm_2 = MetricMonitor()
es_2 = EarlyStopping(monitor='val_loss')
mc_2 = ModelCheckpoint(monitor='val_loss', filename='{epoch}-{val_loss}', save_top_k=3)

trainer_2 = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=20,
    limit_train_batches=100,
    callbacks=[mm_2, es_2, mc_2],
    default_root_dir='./logs/resnext',
)

In [None]:
# Fine-tune the ResNeXt-50 on the same train/validation loaders.
trainer_2.fit(model_2, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# List the checkpoints the ResNeXt ModelCheckpoint wrote during fit, taking the
# directory from the callback instead of a hard-coded run version (version_0).
os.listdir(mc_2.dirpath)

In [None]:
# Evaluate the best (lowest val_loss) ResNeXt checkpoint on the test split.
trainer_2.test(model_2, dataloaders=test_loader, ckpt_path='best')

In [None]:
# Evaluate one specific ResNeXt checkpoint from an earlier run on the test split.
trainer_2.test(
    model_2,
    dataloaders=test_loader,
    ckpt_path='./logs/resnext/lightning_logs/version_0/checkpoints/epoch=8-val_loss=0.15521308779716492.ckpt',
)

In [None]:
import yaml

# hparams.yaml lives in the ResNeXt run's log directory. Resolve it from
# trainer_2's logger rather than hard-coding "version_0", which changes on re-runs.
# NOTE(review): assumes trainer_2 still uses its default logger.
hparams_path = os.path.join(trainer_2.logger.log_dir, "hparams.yaml")

with open(hparams_path, "r", encoding="utf-8") as stream:
    # Only the parse can raise YAMLError; keep the try body to that call.
    try:
        hparams = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)
    else:
        print(hparams)