In [None]:
#import warnings
#warnings.filterwarnings("ignore")

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from kaggle_secrets import UserSecretsClient
import wandb
user_secrets = UserSecretsClient()
api_key = user_secrets.get_secret("WANDB_API_KEY")

wandb.login(key = api_key)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import seed_everything

import torchvision
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pathlib import Path
from PIL import Image

import os
import random

seed_everything(4)

4

In [None]:
no_action_path = list(Path("/kaggle/input/chart-multi/0").glob('*.png'))
buy_path = list(Path("/kaggle/input/chart-multi/1").glob('*.png'))
sell_path = list(Path("/kaggle/input/chart-multi/2").glob('*.png'))

num_images = 3
no_action_samples = random.sample(no_action_path, num_images)
buy_samples = random.sample(buy_path, num_images)
sell_samples = random.sample(sell_path, num_images)

fig = plt.figure(figsize=(20,15))
images = []
labels = []

for no_action, buy, sell in zip(no_action_samples, buy_samples, sell_samples):
    with open(no_action, 'rb') as f_no_action, \
         open(buy, 'rb') as f_buy, \
         open(sell, 'rb') as f_sell:
        no_action_img, buy_img, sell_img = Image.open(f_no_action), Image.open(f_buy), Image.open(f_sell)
        no_action_img, buy_img, sell_img = no_action_img.convert("RGB"), buy_img.convert("RGB"), sell_img.convert("RGB")
        images.extend([no_action_img, buy_img, sell_img])
        labels.extend(['no-action', 'buy', 'sell'])
        

for i, (label, image) in enumerate(zip(labels, images), start=1):
    ax = plt.subplot(num_images, 3, i)
    ax.axis('off')
    ax.set_title(label)
    ax.imshow(image)
    
fig.tight_layout()
fig.show()

In [None]:
class Charts_DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, image_size=224, seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.image_size = image_size
        self.seed = seed
        
    def prepare_data(self):
        # we might need to implement this for combining all the augmented data sets into a single folder
        pass
    
    def setup(self, stage=None):
        # Define transformations
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])       
    
        dataset_train = torchvision.datasets.ImageFolder("/kaggle/input/test-multi-2024", transform=transform)
        self.train_dataset, self.val_dataset = random_split(dataset_train, (0.8, 0.20), generator=torch.Generator().manual_seed(self.seed))
    
        dataset_test = torchvision.datasets.ImageFolder("/kaggle/input/test-multi-2024", transform=transform)
        self.test_dataset = dataset_test
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=os.cpu_count())
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=os.cpu_count())
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=os.cpu_count())
      

In [6]:
import timm  # Assuming you're using the timm library for ViT

class ViT_LightningModule(pl.LightningModule):
    def __init__(self, num_classes: int, learning_rate=2e-4, pretrained=True):
        super().__init__()
        self.model = timm.create_model('vit_base_patch16_224', pretrained=pretrained, num_classes=num_classes)
        self.learning_rate = learning_rate
        self.loss_fn = nn.CrossEntropyLoss()
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.loss_fn(outputs, labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def evaluate(self, batch, stage=None):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.loss_fn(outputs, labels)
        
        _, predicted_labels = torch.max(outputs, 1)
        correct_predictions = (predicted_labels == labels).sum().item()
        total_predictions = labels.size(0)
        acc = correct_predictions / total_predictions
        
        if stage:
            self.log(f'{stage}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            self.log(f'{stage}_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            
    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")
    
    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

In [11]:
# Training the model
os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

trainer = pl.Trainer(
    default_root_dir="/kaggle/working",
    accelerator="auto",
    devices=1,
    max_epochs=2,
    #strategy="ddp_notebook",
    logger = WandbLogger(name = "test"),
    log_every_n_steps=5,
    callbacks=[
        ModelCheckpoint(
            save_weights_only=True,
            mode="max",
            monitor="val_acc",
            dirpath="/kaggle/working/checkpoints",
            filename="{epoch}-{val_loss:.2f}-{val_acc:.2f}"
        ),
        EarlyStopping(
            monitor="val_acc",
            mode="max",
            patience=5,
            verbose=False
        )
    ],
)

dm = Charts_DataModule(batch_size=32, image_size=224, seed=0)
dm.setup()

checkpoints = list(Path("/kaggle/working/checkpoints").glob("*.ckpt"))
if checkpoints:
    print("Loading model from checkpoint...")
    checkpoint = max(checkpoints, key=lambda c: c.stat().st_ctime) # Grab the most recent checkpoint
    model = ViT_LightningModule.load_from_checkpoint(str(checkpoint.resolve()))
else:
    print("Training...")
    model = ViT_LightningModule(num_classes=3, pretrained=True)
    trainer.fit(model, datamodule=dm)
    # Load best checkpoint after training
    model = ViT_LightningModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

Loading model from checkpoint...


In [12]:
#trainer.validate(model, datamodule=dm)

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss_epoch': 0.1775885373353958, 'val_acc_epoch': 0.9651415944099426}]

In [9]:
#trainer.test(model, datamodule=dm)

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss_epoch': 0.26019471883773804,
  'test_acc_epoch': 0.9438642263412476}]