In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import os

from MMClassifyFunc.train import Trainer
from MMClassifyFunc.models import CustomResNet
from MMClassifyFunc.data_preprocess import get_loader
from MMClassifyFunc.data_read import get_data_png
from MMClassifyFunc.visualization import visualize_results

# Restrict this process to GPU 0; assigned before any CUDA query made in this cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# NOTE(review): no random seed is set in this cell, so any randomness in data
# loading or training is presumably not repeatable between runs -- consider seeding.

folder_path = r'/home/mambauser/MMCode/data/processed1D'
in_channels = 3

# Training uses file indices 0..39 except 10, 11, 30 and 31; the evaluation cell
# further down loads exactly those four indices as its data.
held_out_files = (10, 11, 30, 31)
train_file_index = [idx for idx in range(40) if idx not in held_out_files]

# Selections that were tried earlier and are currently disabled:
#   fileIndex=list(range(0,10))+list(range(30,40))
#   fileIndex=list(range(0,40))
#   wordIndex=list(range(5))
#   personIndex=[1]
samples, labels = get_data_png(
    folder_path=folder_path,
    in_channels=in_channels,
    fileIndex=train_file_index,
    txIndex=[0, 1],
)

# Quick sanity check: number of samples and number of distinct label values.
print("len(samples): {}".format(len(samples)))
print("len(set(labels)): {}".format(len(set(labels))))

# get_loader hands back the loader pair consumed by trainer.train below.
trainloader, testloader = get_loader(samples=samples, labels=labels)

# Optimiser hyper-parameters.
lr = 1e-3
betas = (.5, .99)

# One output per distinct label value found in the loaded training data.
# NOTE(review): assumes the elements of `labels` are hashable by value (e.g. ints
# or strings) -- confirm against get_data_png.
num_classes = len(set(labels))

# CustomResNet configured as 'resnet18', started from torchvision's default
# ResNet-18 weights (the disabled alternative was weights=None).
classifier = CustomResNet(
    in_channels=in_channels,
    num_classes=num_classes,
    weights=models.ResNet18_Weights.DEFAULT,
    model='resnet18',
)

# Loss and optimiser handed to the Trainer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=lr, betas=betas)

# Training configuration.
NUM_INPUTS = 1
epochs = 30

# Resolve CUDA availability once and derive the device from it.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Bundle model, optimiser and loss into the project's Trainer.
trainer = Trainer(
    num_inputs=NUM_INPUTS,
    classifier=classifier,
    optimizer=optimizer,
    criterion=criterion,
    print_every=1,
    device=device,
    use_cuda=use_cuda,
    use_scheduler=False,
)

# Run the training loop, then plot whatever the trainer recorded.
trainer.train(
    trainloader=trainloader,
    testloader=testloader,
    epochs=epochs,
)
visualize_results(trainer=trainer)

# Optional: persist the trained weights for the checkpoint-loading path in the
# evaluation cell.
# torch.save(classifier.state_dict(), '/home/mambauser/MMCode/data/model1d.pth')


In [None]:
import torch
from MMClassifyFunc.models import CustomResNet
from MMClassifyFunc.data_read import get_data_png
from tqdm import tqdm

from MMClassifyFunc.data_preprocess import get_loader_all
from MMClassifyFunc.visualization import visualize_predict

# Evaluation data: same folder and channel count as the training cell, but only
# file indices 10, 11, 30 and 31 (the ones the training selection leaves out).
folder_path = r'/home/mambauser/MMCode/data/processed1D'
in_channels = 3
eval_file_index = [10, 11, 30, 31]

# This rebinds `samples` / `labels` from the training cell.
# NOTE(review): txIndex is not passed here, while the training cell passes
# txIndex=[0,1] -- confirm that evaluating without that filter is intended.
# Disabled filters from earlier experiments:
#   wordIndex=list(range(5)), personIndex=[1], txIndex=[0,4,8]
samples, labels = get_data_png(
    folder_path=folder_path,
    in_channels=in_channels,
    fileIndex=eval_file_index,
)

# Quick sanity check: number of samples and number of distinct label values.
print("len(samples): {}".format(len(samples)))
print("len(set(labels)): {}".format(len(set(labels))))

# Single loader built from the evaluation samples and labels.
dataloader = get_loader_all(samples, labels)

# The network evaluated here is the one still held by `trainer` from the training
# cell, so that cell must have been run first in this kernel.
#
# Disabled alternative: rebuild the network and load saved weights instead.
# model_path = '/home/mambauser/MMCode/data/model1d.pth'  # Path to your saved model
# classifier = CustomResNet(in_channels=in_channels,
#                           num_classes=len(set(labels)),
#                           model='resnet18')
# classifier.load_state_dict(torch.load(model_path,weights_only=True))
# classifier.eval()
# classifier.to(device)

# Pick the device, move the network onto it, then switch it to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trainer.classifier.to(device)
trainer.classifier.eval()

# Collect target and predicted class indices over the whole evaluation loader,
# then hand both lists to visualize_predict.
all_labels = []
all_preds = []

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    # The loop targets are deliberately not named `images` / `labels`: `labels`
    # is the collection returned by get_data_png earlier in this cell, and
    # rebinding it to the last batch tensor would silently break any later use
    # of it (for example len(set(labels))).
    for batch_images, batch_labels in tqdm(dataloader, desc='Processing batches'):
        outputs = trainer.classifier(batch_images.to(device))
        # Predicted class = index of the largest output along dim 1
        # (assumes outputs is (batch, num_classes) -- TODO confirm).
        predicted = outputs.argmax(dim=1)

        # Targets are only recorded, never fed to the model, so they are not
        # copied to the device first.
        all_labels.extend(batch_labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

visualize_predict(all_labels, all_preds)
