In [31]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torchvision.transforms import ToTensor, Normalize, Resize, Compose
from utils.misc import print_h5py, elapsed_time
import h5py
from utils.confusion_matrix import ConfusionMatrix
import importlib
from datetime import datetime
from argparse import ArgumentParser
import wandb


In [2]:
# HDF5 files holding the original datasets ('ori') and the generated samples ('gen').
paths = {'ori': '/shared_hdd/sin/data.h5py',
         'gen': '/shared_hdd/sin/gen.h5py'}

# Generator run ids and the dataset each run belongs to; the two lists are
# index-aligned (ids[i] pairs with data_names[i]) and are indexed together
# with `args.data_seq` when the datasets are built.
ids = ["1r08m7xe",
       "3rczol1e",
       "3u337fao",
       "6gem8lca",
       "3p2vbr49",
       "i2mijyme",
       "iitr6wnc",
       "vpgfgriu"]

data_names = ["CIFAR10_LT",
              "CIFAR10_LT",
              "CIFAR10_LT",
              "CIFAR10_LT",
              "FashionMNIST_LT",
              "FashionMNIST_LT",
              "FashionMNIST_LT",
              "FashionMNIST_LT"]

assert len(ids) == len(data_names), "ids and data_names must be index-aligned"


def gen_query(id, data_name):
    """Return the HDF5 group path of the samples generated by run `id` for `data_name`."""
    return f'gen/{id}/{data_name}'


def ori_query(data_name, split=None):
    """Return the HDF5 group path of the original `data_name` dataset.

    If `split` (e.g. 'train' or 'test') is given, the path of that split's group
    is returned instead.
    """
    if split is None:
        return f'ori/{data_name}'
    return f'ori/{data_name}/{split}'

In [166]:
# Print the group/dataset layout of both HDF5 files (generated samples first,
# then the original datasets), reusing the paths from the config cell.
for key in ('gen', 'ori'):
    with h5py.File(paths[key], 'r') as file:
        print_h5py(file)

gen
	1r08m7xe
		CIFAR10_LT
			data
			(40500, 64, 64, 3)
			targets
			(40500,)
	3p2vbr49
		FashionMNIST_LT
			data
			(49500, 64, 64, 3)
			targets
			(49500,)
	3rczol1e
		CIFAR10_LT
			data
			(40500, 64, 64, 3)
			targets
			(40500,)
	3u337fao
		CIFAR10_LT
			data
			(40500, 64, 64, 3)
			targets
			(40500,)
	6gem8lca
		CIFAR10_LT
			data
			(40500, 64, 64, 3)
			targets
			(40500,)
	i2mijyme
		FashionMNIST_LT
			data
			(49500, 64, 64, 3)
			targets
			(49500,)
	iitr6wnc
		FashionMNIST_LT
			data
			(49500, 64, 64, 3)
			targets
			(49500,)
	vpgfgriu
		FashionMNIST_LT
			data
			(49500, 64, 64, 3)
			targets
			(49500,)
ori
	CIFAR10_LT
		test
			data
			(10000, 32, 32, 3)
			targets
			(10000,)
		train
			data
			(9500, 32, 32, 3)
			targets
			(9500,)
	FashionMNIST_LT
		test
			data
			(10000, 28, 28)
			targets
			(10000,)
		train
			data
			(10500, 28, 28)
			targets
			(10500,)
	Places_LT
		test
			data
			(36500, 256, 256, 3)
			targets
			(36500,)
		train
			data
			(62500, 2

In [18]:
class H5py_dataset(Dataset):
    """Map-style dataset over the ``<query>/data`` and ``<query>/targets`` datasets of an HDF5 file.

    Args:
        path: Path of the HDF5 file.
        query: Group path inside the file that contains ``data`` and ``targets``.
        transforms: Optional callable applied to every image (e.g. a torchvision ``Compose``).

    Attributes:
        data: Lazy uint8 ``astype`` view of ``<query>/data``; items are read from the
            still-open file on each index.
        targets: In-memory int64 labels. int64 avoids wrapping label ids >= 256 (as a
            uint8 cast would) and is the canonical target dtype of ``nn.CrossEntropyLoss``.
        classes: Sorted unique label values.
        num_classes: Per-class sample counts (historical name, kept for compatibility).
        class_counts: Same array as ``num_classes``.
        num_class: Number of distinct labels.

    NOTE(review): the single open h5py handle is presumably unsafe to share across
    DataLoader worker processes — confirm before using num_workers > 0.
    """

    def __init__(self, path, query, transforms=None):
        # Kept open because `data` is read lazily in __getitem__; see close().
        self._file = h5py.File(path, 'r')
        self.data = self._file[f'{query}/data'].astype(np.uint8)
        # Targets are small: read them fully (dset[()]) and widen to int64.
        self.targets = self._as_labels(self._file[f'{query}/targets'][()])
        self.transforms = transforms
        self.classes, self.num_classes = np.unique(self.targets, return_counts=True)
        self.class_counts = self.num_classes
        self.num_class = len(self.classes)

    @staticmethod
    def _as_labels(raw):
        """Return `raw` as an int64 numpy array of class labels (no narrowing cast)."""
        return np.asarray(raw).astype(np.int64)

    def close(self):
        """Close the underlying HDF5 file; the dataset cannot be indexed afterwards."""
        self._file.close()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx], self.targets[idx]

        # 2-D (grayscale) images are replicated into three channels along a new last axis.
        if img.ndim == 2:
            img = img[..., None].repeat(3, 2)

        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

# ToTensor maps uint8 HWC images to float CHW in [0, 1]; Resize(64) resizes the
# smaller edge to 64 px; Normalize(0.5, 0.5) then maps values to [-1, 1].
transforms = Compose([ToTensor(),
                      Resize(64),
                      Normalize(mean=[0.5, 0.5, 0.5],
                                std=[0.5, 0.5, 0.5])])

# NOTE: `args` is created by the argparse cell further down this notebook; that cell
# must be executed first (a fresh Restart & Run All otherwise fails with NameError).
run_id = ids[args.data_seq]
data_name = data_names[args.data_seq]

gen_dataset = H5py_dataset(path=paths['gen'], query=gen_query(run_id, data_name), transforms=transforms)
ori_dataset = H5py_dataset(path=paths['ori'], query=f'{ori_query(data_name)}/train', transforms=transforms)

# Train on original + generated samples; evaluate on the original test split only.
train_dataset = ConcatDataset([ori_dataset, gen_dataset])
test_dataset = H5py_dataset(path=paths['ori'], query=f'{ori_query(data_name)}/test', transforms=transforms)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)


In [58]:
# Train `args.model` on original + generated data for `args.steps` optimizer steps,
# evaluating on the original test split every EVAL_INTERVAL steps and logging to wandb.
# Requires `args` (argparse cell) and an active `wandb.init(...)` run.
EVAL_INTERVAL = 100  # steps between evaluations; train loss is averaged over this window

device = torch.device(args.gpus)

module_model = importlib.import_module('torchvision.models')
model = getattr(module_model, args.model)(num_classes=ori_dataset.num_class).to(device)

# Honour --lr (it was hard-coded to 0.001 here, which is also the argparse default).
optimizer = SGD(model.parameters(), lr=args.lr)
loss_fn = nn.CrossEntropyLoss()

cm_train = ConfusionMatrix(ori_dataset.num_class)
cm_test = ConfusionMatrix(ori_dataset.num_class)
global_epoch = 0   # completed passes over train_loader
losses_train = 0   # running sum of train losses in the current window
losses_test = 0    # running sum of per-batch test losses in the current evaluation

best_acc = 0

model.train()
start_time = datetime.now()
train_iter = iter(train_loader)
for step in range(args.steps):
    # Step-based schedule: restart the loader whenever it is exhausted.
    try:
        img, label = next(train_iter)
    except StopIteration:
        global_epoch += 1
        train_iter = iter(train_loader)
        img, label = next(train_iter)
    img, label = img.to(device), label.to(device)

    optimizer.zero_grad()
    y_hat = model(img)
    loss = loss_fn(y_hat, label)
    loss.backward()
    optimizer.step()

    # argmax over the logits is the predicted class; a softmax first is redundant.
    pred = y_hat.argmax(1)
    losses_train += loss.item()
    cm_train.update(label.cpu().numpy(), pred.cpu().numpy())

    if (step + 1) % EVAL_INTERVAL == 0:
        model.eval()
        # Evaluation needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for img, label in test_loader:
                img, label = img.to(device), label.to(device)
                y_hat = model(img)
                losses_test += loss_fn(y_hat, label).item()
                pred = y_hat.argmax(1)
                cm_test.update(label.cpu().numpy(), pred.cpu().numpy())

        losses_train = losses_train / EVAL_INTERVAL
        losses_test = losses_test / len(test_loader)  # mean of per-batch losses
        acc_train = cm_train.getAccuracy()
        acc_test = cm_test.getAccuracy()
        acc_train_per_class = cm_train.getAccuracyPerClass()
        acc_test_per_class = cm_test.getAccuracyPerClass()

        acc_train_cls = {f'acc/train/{idx}': acc for idx, acc in enumerate(acc_train_per_class)}
        acc_test_cls = {f'acc/test/{idx}': acc for idx, acc in enumerate(acc_test_per_class)}

        if best_acc < acc_test:
            best_acc = acc_test
            wandb.summary['best/acc'] = acc_test
            # Same 1-based step number as used by wandb.log and the printout.
            wandb.summary['best/step'] = step + 1
            # Iterate the per-class accuracies themselves; enumerating the dict above
            # would store its key strings instead of the accuracy values.
            for idx, acc in enumerate(acc_test_per_class):
                wandb.summary[f'best/acc/{idx}'] = acc

        wandb.log({'loss/train': losses_train,
                   'loss/test': losses_test,
                   'acc/train': acc_train,
                   'acc/test': acc_test,
                   **acc_train_cls,
                   **acc_test_cls}, step=(step + 1))

        print(f'step: {step+1}/{args.steps}({((step+1) / args.steps)*100:.2f}%), '
              f'time: {elapsed_time(start_time)}, '
              f'loss train: {losses_train}, '
              f'loss test: {losses_test}, '
              f'acc train: {acc_train}, '
              f'acc test: {acc_test}, '
              f'acc best: {best_acc}')

        # Start a fresh window for the next EVAL_INTERVAL steps.
        model.train()
        losses_train = 0
        losses_test = 0
        cm_train.reset()
        cm_test.reset()

step: 100/1000(10.00%), time: 0:00:18, loss train: 1.8067075848579406, loss test: 2.2844077575055857, acc train: 0.571796875, acc test: 0.1667, acc best: 0.1667
step: 200/1000(20.00%), time: 0:00:37, loss train: 1.1909965401887894, loss test: 2.2423489652102506, acc train: 0.85421875, acc test: 0.1903, acc best: 0.1903
step: 300/1000(30.00%), time: 0:00:56, loss train: 0.8435622948408127, loss test: 2.2512957189656513, acc train: 0.8728125, acc test: 0.1971, acc best: 0.1971
step: 400/1000(40.00%), time: 0:01:15, loss train: 0.6620379745960235, loss test: 2.2939400348482253, acc train: 0.8805677540777918, acc test: 0.1983, acc best: 0.1983
step: 500/1000(50.00%), time: 0:01:34, loss train: 0.5553600159287453, loss test: 2.285524070262909, acc train: 0.883125, acc test: 0.2046, acc best: 0.2046
step: 600/1000(60.00%), time: 0:01:52, loss train: 0.4775961628556252, loss test: 2.3138351817674274, acc train: 0.8946875, acc test: 0.1963, acc best: 0.2046
step: 700/1000(70.00%), time: 0:02:1

In [57]:
# Run configuration. Passing an explicit empty list means the kernel's own argv is
# not parsed, so the defaults below are always used inside the notebook.
# NOTE(review): `args` is read by the dataset and training cells that appear above
# this one, and the training loop calls wandb.log/wandb.summary; move this cell to
# the top of the notebook so Restart & Run All works.
parser = ArgumentParser()
parser.add_argument("--model", type=str, default='resnet18')   # name looked up in torchvision.models
parser.add_argument("--gpus", type=int, default=1)             # passed to torch.device(...)
parser.add_argument("--data_seq", type=int, default=0)         # index into ids / data_names
parser.add_argument("--steps", type=int, default=1000)         # number of training steps
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--batch_size", type=int, default=128)     # DataLoader batch size
args = parser.parse_args(args=[])

wandb.init(project="cls_eval", entity="sinaenjuni", config=args)




Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msinaenjuni[0m. Use [1m`wandb login --relogin`[0m to force relogin
