In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from itertools import repeat
import os


class AstDataset(Dataset):
    """In-memory dataset of tensor chunks stored in per-class files.

    Every file in ``dir`` whose name contains ``mode`` is read with
    ``torch.load`` and the results are concatenated along dimension 0.
    The class name of a file is the part of its name before the first
    underscore. Class ids are the positions of those names in sorted order,
    so all files of one class share one id regardless of how many files the
    directory holds per class.

    Note: the id mapping is built from the files matching ``mode`` only, so
    two datasets (e.g. train and test) get identical ids only when they cover
    the same set of class names.

    Attributes:
        tensors: all loaded chunks concatenated along dimension 0.
        labels: ``torch.long`` tensor with one class id per chunk.
        label_to_id: mapping from class name to integer id.

    Raises:
        ValueError: if no file name in ``dir`` contains ``mode``.
    """

    def __init__(self, mode: str, dir: str):
        file_names = sorted(f for f in os.listdir(dir) if mode in f)
        if not file_names:
            # torch.cat on an empty list fails with an unhelpful message.
            raise ValueError(f"No files containing {mode!r} found in {dir!r}")

        # Derive ids from the class names themselves instead of the position
        # of the file in the directory listing.
        label_names = sorted({f.split('_')[0] for f in file_names})
        self.label_to_id = {name: i for i, name in enumerate(label_names)}

        tensors = []
        labels = []
        for file_name in file_names:
            # NOTE(review): torch.load may unpickle arbitrary objects; only
            # point this at trusted files.
            tensor = torch.load(os.path.join(dir, file_name))
            tensors.append(tensor)
            label_id = self.label_to_id[file_name.split('_')[0]]
            # One label per chunk (row along dimension 0) of the file.
            labels.extend([label_id] * tensor.shape[0])

        self.tensors = torch.cat(tensors)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        """Return the total number of chunks."""
        return len(self.tensors)

    def __getitem__(self, idx):
        """Return the ``(chunk, label)`` pair at position ``idx``."""
        return self.tensors[idx], self.labels[idx]


# Both splits are read from the same directory; the first argument selects
# which file names are picked up.
train_dataset = AstDataset('train', './data/train_chunks')
test_dataset = AstDataset('test', './data/train_chunks')

# Only the training loader shuffles; every other setting is shared.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
    persistent_workers=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=6,
    pin_memory=True,
    persistent_workers=True,
)


In [2]:
import torch.nn as nn
from tqdm import tqdm
from torch.amp import GradScaler, autocast
from sklearn.metrics import f1_score, accuracy_score
import numpy as np
import transformers




# Pretrained checkpoint that the classifier is fine-tuned from.
model_name = "google/vit-base-patch16-224-in21k"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

labels = train_dataset.labels.numpy()

# class_counts[i] is the number of training samples carrying label id i.
class_counts = np.bincount(labels)
# Size the classification head from the count vector so that the head and the
# loss weight vector always have the same length, even if some id below the
# maximum has no training samples.
num_classes = len(class_counts)

# from_pretrained already creates a freshly initialised head with
# `num_labels` outputs, so the head is not replaced manually afterwards.
model = transformers.ViTForImageClassification.from_pretrained(model_name, num_labels=num_classes)
model.to(device)

# Inverse-frequency class weights; the epsilon keeps the division finite for
# ids that have no samples.
class_weights = 1. / (class_counts+1e-9)
class_weights = torch.FloatTensor(class_weights).to(device)

loss_fn = nn.CrossEntropyLoss(weight=class_weights).to(device)
best_f1 = 0  # best macro-F1 seen so far; gates checkpoint writes
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)
scaler = GradScaler(device)

for epoch in range(20):
    # ---- training pass ----
    model.train()
    epoch_train_loss, epoch_test_loss = 0,0

    for data, label in tqdm(train_loader, desc='Train-Batches'):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        with autocast(device):
            train_output = model(data).logits
            loss_val = loss_fn(train_output, label)

        epoch_train_loss += loss_val.item()
        # Scale the loss before backward; the scaler unscales inside step().
        scaler.scale(loss_val).backward()
        scaler.step(optimizer)
        scaler.update()

    avg_epoch_train_loss = epoch_train_loss / len(train_loader)
    print(f"Average Epoch Train Loss of {epoch+1}th Epoch: {avg_epoch_train_loss:.4f}")

    # ---- evaluation pass ----
    all_preds, all_labels = [], []
    model.eval()
    with torch.no_grad():
        for data, label in tqdm(test_loader, desc='Test-Batches'):
            data, label = data.to(device), label.to(device)
            with autocast(device):
                test_output = model(data).logits
                test_loss = loss_fn(test_output, label)
            epoch_test_loss += test_loss.item()
            prediction = test_output.argmax(dim=1)
            all_preds.append(prediction.cpu().numpy())
            all_labels.append(label.cpu().numpy())

    all_preds, all_labels = np.concatenate(all_preds), np.concatenate(all_labels)

    f1 = f1_score(all_labels, all_preds, average='macro')
    acc = accuracy_score(all_labels, all_preds)
    avg_epoch_test_loss = epoch_test_loss/len(test_loader)
    print(f'Average Epoch Test Loss of {epoch+1}th Epoch: {avg_epoch_test_loss:.4f}')
    print(f'Accuracy of Epoch {epoch+1}: {acc}')
    print(f'F1-Score of Epoch {epoch+1}: {f1}')

    # Keep only the checkpoint with the best macro-F1. The running best must
    # be updated, otherwise every later epoch overwrites the file.
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), './bestViT.pth')

  from .autonotebook import tqdm as notebook_tqdm
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Train-Batches: 100%|██████████| 788/788 [05:04<00:00,  2.59it/s]


Average Epoch Train Loss of 1th Epoch: 4.7190


Test-Batches: 100%|██████████| 533/533 [01:16<00:00,  6.95it/s]


Average Epoch Test Loss of 1th Epoch: 5.1451
Accuracy of Epoch 1: 0.007653510058061111
F1-Score of Epoch 1: 0.006013439587123013


Train-Batches: 100%|██████████| 788/788 [05:01<00:00,  2.61it/s]


Average Epoch Train Loss of 2th Epoch: 3.9556


Test-Batches: 100%|██████████| 533/533 [01:13<00:00,  7.22it/s]


Average Epoch Test Loss of 2th Epoch: 5.1938
Accuracy of Epoch 2: 0.010409946630696146
F1-Score of Epoch 2: 0.007146739040374429


Train-Batches:  18%|█▊        | 138/788 [00:53<04:12,  2.58it/s]


KeyboardInterrupt: 

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np 
import transformers
import torch.nn as nn
from tqdm import tqdm
from torch.amp import GradScaler, autocast
from sklearn.metrics import f1_score, accuracy_score


class AstDataset(Dataset):
    """Lazily loaded dataset of tensor chunks stored in per-class files.

    Every file in ``data_dir`` whose name contains ``mode`` contributes one
    sample per entry along dimension 0 of the tensor it stores. The class
    name of a file is the part of its name before the first underscore, and
    class ids are the positions of those names in sorted order.

    Only an index of ``(file_path, chunk_idx, label_id)`` tuples is kept in
    memory; tensor data is read when a sample is requested.

    Attributes:
        data_info: list of ``(file_path, chunk_idx, label_id)`` tuples.
        label_to_id: mapping from class name to integer id.
    """

    def __init__(self, mode: str, data_dir: str):
        self.data_info = []
        self.label_to_id = {}

        # List the directory once and reuse the filtered, sorted names.
        file_names = sorted(f for f in os.listdir(data_dir) if mode in f)

        for label_name in sorted({f.split('_')[0] for f in file_names}):
            self.label_to_id[label_name] = len(self.label_to_id)

        print(f"Discovered {len(self.label_to_id)} unique labels for {mode} mode.")

        for file_name in file_names:
            label_id = self.label_to_id[file_name.split('_')[0]]
            file_path = os.path.join(data_dir, file_name)

            # Only the shape is needed here; memory-mapping avoids pulling
            # the full tensor data into memory just to count chunks.
            num_chunks_in_file = self._load_file(file_path).shape[0]

            self.data_info.extend(
                (file_path, chunk_idx, label_id)
                for chunk_idx in range(num_chunks_in_file)
            )

        print(f"Loaded {len(self.data_info)} total {mode} chunks.")

    @staticmethod
    def _load_file(file_path):
        """Load the tensor stored at ``file_path``, memory-mapped when possible.

        Falls back to a regular ``torch.load`` when memory-mapping is not
        available (``mmap`` keyword not supported, or the file was not saved
        in a format that supports it).
        """
        # NOTE(review): torch.load may unpickle arbitrary objects; only point
        # this at trusted files.
        try:
            return torch.load(file_path, mmap=True)
        except (RuntimeError, TypeError):
            return torch.load(file_path)

    def __len__(self):
        """Return the total number of chunks across all matching files."""
        return len(self.data_info)

    def __getitem__(self, idx):
        """Return ``(chunk, label)`` for sample ``idx``.

        The chunk is converted to ``torch.float32`` and the label is a
        zero-dimensional ``torch.long`` tensor.
        """
        file_path, chunk_idx, label_id = self.data_info[idx]

        # Index into the memory-mapped file so that fetching one chunk does
        # not require deserialising every chunk in the file.
        data_chunk = self._load_file(file_path)[chunk_idx].to(torch.float32)
        label_tensor = torch.tensor(label_id, dtype=torch.long)

        return data_chunk, label_tensor




train_dataset = AstDataset('train', './data/train_chunks')
test_dataset = AstDataset('test', './data/train_chunks')

# Each split builds its own name -> id mapping. If the mappings differ, the
# ids seen in evaluation would not mean the same classes as in training.
if test_dataset.label_to_id != train_dataset.label_to_id:
    raise ValueError("Train and test splits do not share the same label-to-id mapping.")

# One label id per training chunk, used below for the class frequencies.
labels_for_model_init = np.array([info[2] for info in train_dataset.data_info], dtype=np.int64)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, pin_memory=True)


# Pretrained checkpoint that the classifier is fine-tuned from.
model_name = "google/vit-base-patch16-224-in21k"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# float16 autocast and gradient scaling are only used on GPU; elsewhere both
# are disabled explicitly instead of hard-coding 'cuda'.
use_amp = device == 'cuda'

num_classes = len(train_dataset.label_to_id)
model = transformers.ViTForImageClassification.from_pretrained(model_name, num_labels=num_classes)

model.to(device)

# minlength keeps the weight vector at num_classes entries even when the
# highest ids have no chunks.
class_counts = np.bincount(labels_for_model_init, minlength=num_classes)
# Inverse-frequency class weights; the epsilon keeps the division finite for
# ids that have no samples.
class_weights = 1. / (class_counts + 1e-9)
class_weights = torch.FloatTensor(class_weights).to(device)

loss_fn = nn.CrossEntropyLoss(weight=class_weights).to(device)
best_f1 = 0  # best macro-F1 seen so far; gates checkpoint writes
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)
scaler = GradScaler(device, enabled=use_amp)

for epoch in range(20):
    # ---- training pass ----
    model.train()
    epoch_train_loss, epoch_test_loss = 0,0

    for data, label in tqdm(train_loader, desc='Train-Batches'):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        with autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            train_output = model(data).logits
            loss_val = loss_fn(train_output, label)

        epoch_train_loss += loss_val.item()
        # Scale the loss before backward; the scaler unscales inside step().
        scaler.scale(loss_val).backward()
        scaler.step(optimizer)
        scaler.update()

    avg_epoch_train_loss = epoch_train_loss / len(train_loader)
    print(f"Average Epoch Train Loss of {epoch+1}th Epoch: {avg_epoch_train_loss:.4f}")

    # ---- evaluation pass ----
    all_preds, all_labels = [], []
    model.eval()
    with torch.no_grad():
        for data, label in tqdm(test_loader, desc='Test-Batches'):
            data, label = data.to(device), label.to(device)
            with autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
                test_output = model(data).logits
                test_loss = loss_fn(test_output, label)
            epoch_test_loss += test_loss.item()
            prediction = test_output.argmax(dim=1)
            all_preds.append(prediction.cpu().numpy())
            all_labels.append(label.cpu().numpy())

    all_preds, all_labels = np.concatenate(all_preds), np.concatenate(all_labels)

    f1 = f1_score(all_labels, all_preds, average='macro')
    acc = accuracy_score(all_labels, all_preds)
    avg_epoch_test_loss = epoch_test_loss/len(test_loader)
    print(f'Average Epoch Test Loss of {epoch+1}th Epoch: {avg_epoch_test_loss:.4f}')
    print(f'Accuracy of Epoch {epoch+1}: {acc}')
    print(f'F1-Score of Epoch {epoch+1}: {f1}')

    # Keep only the checkpoint with the best macro-F1 so far.
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), './bestViT.pth')

Discovered 98 unique labels for train mode.
Loaded 29400 total train chunks.
Discovered 98 unique labels for test mode.
Loaded 22280 total test chunks.


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Train-Batches: 100%|██████████| 460/460 [23:11<00:00,  3.02s/it]


Average Epoch Train Loss of 1th Epoch: 4.4282


Test-Batches: 100%|██████████| 349/349 [12:43<00:00,  2.19s/it]


Average Epoch Test Loss of 1th Epoch: 4.2525
Accuracy of Epoch 1: 0.19389587073608616
F1-Score of Epoch 1: 0.1524869902111614


Train-Batches: 100%|██████████| 460/460 [23:22<00:00,  3.05s/it]


Average Epoch Train Loss of 2th Epoch: 4.0401


Test-Batches: 100%|██████████| 349/349 [12:30<00:00,  2.15s/it]


Average Epoch Test Loss of 2th Epoch: 3.9321
Accuracy of Epoch 2: 0.3102333931777379
F1-Score of Epoch 2: 0.26077011920176313


Train-Batches: 100%|██████████| 460/460 [23:25<00:00,  3.05s/it]


Average Epoch Train Loss of 3th Epoch: 3.7154


Test-Batches: 100%|██████████| 349/349 [12:29<00:00,  2.15s/it]


Average Epoch Test Loss of 3th Epoch: 3.6738
Accuracy of Epoch 3: 0.37307001795332134
F1-Score of Epoch 3: 0.3285696619201988


Train-Batches: 100%|██████████| 460/460 [23:11<00:00,  3.03s/it]


Average Epoch Train Loss of 4th Epoch: 3.4199


Test-Batches: 100%|██████████| 349/349 [12:29<00:00,  2.15s/it]


Average Epoch Test Loss of 4th Epoch: 3.4386
Accuracy of Epoch 4: 0.4026032315978456
F1-Score of Epoch 4: 0.35787172807018613


Train-Batches: 100%|██████████| 460/460 [23:22<00:00,  3.05s/it]


Average Epoch Train Loss of 5th Epoch: 3.1370


Test-Batches: 100%|██████████| 349/349 [12:26<00:00,  2.14s/it]


Average Epoch Test Loss of 5th Epoch: 3.2167
Accuracy of Epoch 5: 0.43994614003590665
F1-Score of Epoch 5: 0.39969874424529733


Train-Batches: 100%|██████████| 460/460 [23:30<00:00,  3.07s/it]


Average Epoch Train Loss of 6th Epoch: 2.8628


Test-Batches: 100%|██████████| 349/349 [12:33<00:00,  2.16s/it]


Average Epoch Test Loss of 6th Epoch: 3.0262
Accuracy of Epoch 6: 0.46247755834829446
F1-Score of Epoch 6: 0.4284819388904043


Train-Batches: 100%|██████████| 460/460 [23:31<00:00,  3.07s/it]


Average Epoch Train Loss of 7th Epoch: 2.5941


Test-Batches: 100%|██████████| 349/349 [12:31<00:00,  2.15s/it]


Average Epoch Test Loss of 7th Epoch: 2.8467
Accuracy of Epoch 7: 0.47630161579892283
F1-Score of Epoch 7: 0.4432394514856658


Train-Batches: 100%|██████████| 460/460 [23:23<00:00,  3.05s/it]


Average Epoch Train Loss of 8th Epoch: 2.3400


Test-Batches: 100%|██████████| 349/349 [12:36<00:00,  2.17s/it]


Average Epoch Test Loss of 8th Epoch: 2.7131
Accuracy of Epoch 8: 0.4902603231597846
F1-Score of Epoch 8: 0.4639827462455889


Train-Batches: 100%|██████████| 460/460 [23:51<00:00,  3.11s/it]


Average Epoch Train Loss of 9th Epoch: 2.0943


Test-Batches: 100%|██████████| 349/349 [13:09<00:00,  2.26s/it]


Average Epoch Test Loss of 9th Epoch: 2.5667
Accuracy of Epoch 9: 0.5052513464991023
F1-Score of Epoch 9: 0.47879413621833977


Train-Batches: 100%|██████████| 460/460 [24:06<00:00,  3.14s/it]


Average Epoch Train Loss of 10th Epoch: 1.8548


Test-Batches: 100%|██████████| 349/349 [13:01<00:00,  2.24s/it]


Average Epoch Test Loss of 10th Epoch: 2.4718
Accuracy of Epoch 10: 0.5140035906642729
F1-Score of Epoch 10: 0.4909868191890801


Train-Batches: 100%|██████████| 460/460 [24:22<00:00,  3.18s/it]


Average Epoch Train Loss of 11th Epoch: 1.6271


Test-Batches: 100%|██████████| 349/349 [12:55<00:00,  2.22s/it]


Average Epoch Test Loss of 11th Epoch: 2.3782
Accuracy of Epoch 11: 0.5175942549371634
F1-Score of Epoch 11: 0.495018972724812


Train-Batches:   0%|          | 2/460 [00:07<28:01,  3.67s/it]


KeyboardInterrupt: 