In [None]:
import glob
import os
import sys
import warnings

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, WeightedRandomSampler

# Make the project's `src/` directory importable from this notebook.
# The project root is taken to be the parent of the current working directory
# (i.e. the notebook is expected to run from a sub-folder of the project).
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
src_path = os.path.join(project_root, 'src')
sys.path.insert(0, src_path)  # prepend so local modules win over installed ones


from train import validate_model, train_model, plot_training_history
from models import CTClassifier
from datasets import SliceDataset

In [None]:
# Run configuration.
batch_size = 64
num_workers = 4  # DataLoader worker processes
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10
random_state = 42  # single seed shared by the data split, numpy and torch

# Seed the global RNGs up front so that later cells drawing from them
# (e.g. model weight initialisation, presumably) repeat on Restart & Run All.
np.random.seed(random_state)
torch.manual_seed(random_state)  # seeds the CPU and all CUDA devices

In [None]:
# Binary CT-slice classification: every file under `normal_dirs` is labelled 0.0
# (normal) and every file under `pathology_dirs` is labelled 1.0 (any pathology).
normal_dirs = [
    '/kaggle/input/ct-images/CT/normal',
    '/kaggle/input/iqothnccd-lung-cancer-dataset/The IQ-OTHNCCD lung cancer dataset/The IQ-OTHNCCD lung cancer dataset/Normal cases',
    '/kaggle/input/large-covid19-ct-slice-dataset/curated_data/curated_data/1NonCOVID',
]

pathology_dirs = [
    '/kaggle/input/ct-images/CT/cancer',
    # Spelling kept as-is: the path must match the folder name on disk.
    '/kaggle/input/iqothnccd-lung-cancer-dataset/The IQ-OTHNCCD lung cancer dataset/The IQ-OTHNCCD lung cancer dataset/Bengin cases',
    '/kaggle/input/iqothnccd-lung-cancer-dataset/The IQ-OTHNCCD lung cancer dataset/The IQ-OTHNCCD lung cancer dataset/Malignant cases',
    '/kaggle/input/large-covid19-ct-slice-dataset/curated_data/curated_data/2COVID',
    # TODO(review): "3CAP" presumably means community-acquired pneumonia (non-COVID);
    # it is treated as pathology here — confirm against the dataset's documentation.
    '/kaggle/input/large-covid19-ct-slice-dataset/curated_data/curated_data/3CAP',
]


def collect_files(directories, label):
    """Return ``(file_paths, file_labels)`` for the regular files directly inside ``directories``.

    Each directory listing is sorted because ``glob`` returns entries in
    filesystem order; without sorting, the seeded train/val/test split below
    would not be reproducible across machines. Sub-directories are skipped, and
    a warning is emitted for any directory that yields no files (e.g. a Kaggle
    dataset that is not attached) instead of silently shrinking the data.

    Args:
        directories: iterable of directory paths, scanned non-recursively.
        label: label assigned to every file found.

    Returns:
        A list of file paths and a list of the same length filled with ``label``.
    """
    file_paths = []
    for directory in directories:
        # glob.escape: treat the directory name literally, not as a pattern.
        dir_files = sorted(
            path for path in glob.glob(os.path.join(glob.escape(directory), '*'))
            if os.path.isfile(path)
        )
        if not dir_files:
            warnings.warn(f"No files found in {directory!r}; is the dataset attached?")
        file_paths.extend(dir_files)
    return file_paths, [label] * len(file_paths)


paths = []
labels = []
# Pathology directories are scanned first, then normal ones.
for class_dirs, class_label in ((pathology_dirs, 1.0), (normal_dirs, 0.0)):
    class_paths, class_labels = collect_files(class_dirs, class_label)
    paths.extend(class_paths)
    labels.extend(class_labels)

if not paths:
    raise FileNotFoundError("No images found in any of the configured dataset directories.")

print(f"Всего образцов: {len(paths)}")
print(f"Норма (CT-0): {labels.count(0)}")
print(f"Патология: {labels.count(1)}")

# Split: 70% train / 15% validation / 15% test.
# Stage 1 holds out 30% of the data; stage 2 halves that hold-out into
# validation and test. Both stages stratify on the labels, so every subset keeps
# the overall normal/pathology ratio, and both use the same seed.
# NOTE(review): the split is per file (slice). If several slices come from the
# same patient/scan, they can land in different subsets and inflate val/test
# scores — consider a group-wise split if patient IDs are recoverable.
common_split_args = {'random_state': random_state, 'shuffle': True}

train_samples, temp_samples, train_labels, temp_labels = train_test_split(
    paths, labels, test_size=0.3, stratify=labels, **common_split_args
)
val_samples, test_samples, val_labels, test_labels = train_test_split(
    temp_samples, temp_labels, test_size=0.5, stratify=temp_labels, **common_split_args
)

print(f"\nTrain: {len(train_samples)} ({len(train_samples)/len(paths):.1%})")
print(f"Val:   {len(val_samples)} ({len(val_samples)/len(paths):.1%})")
print(f"Test:  {len(test_samples)} ({len(test_samples)/len(paths):.1%})")

In [None]:
# Wrap each split's (paths, labels) pair in the project's SliceDataset,
# in train → val → test order.
train_dataset, val_dataset, test_dataset = (
    SliceDataset(split_paths, split_targets)
    for split_paths, split_targets in (
        (train_samples, train_labels),
        (val_samples, val_labels),
        (test_samples, test_labels),
    )
)

# Class-balanced sampling for training: each sample is weighted by the inverse
# size of its class, so normal and pathology slices are drawn with equal
# expected frequency even though the training split is imbalanced.
train_labels_array = np.array(train_labels)
num_neg = (train_labels_array == 0).sum()
num_pos = (train_labels_array == 1).sum()

# Both classes must be present; otherwise the ratio below divides by zero and
# the "balanced" sampler would silently feed a single class.
if num_neg == 0 or num_pos == 0:
    raise ValueError(
        f"Training split must contain both classes (normal={num_neg}, pathology={num_pos})."
    )

print(f"Train balance — норма: {num_neg}, патология: {num_pos} (ratio pos/neg = {num_pos/num_neg:.2f})")

class_weights = 1.0 / np.array([num_neg, num_pos])
# Labels are 0.0/1.0 floats; cast to int so they can index class_weights.
sample_weights = class_weights[train_labels_array.astype(int)]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # one epoch = as many draws as training samples
    replacement=True,  # needed so the minority class can be oversampled
    # Dedicated, seeded generator: the draw order is reproducible regardless of
    # how much other code has consumed the global torch RNG.
    generator=torch.Generator().manual_seed(random_state),
)

# Page-locked host memory only speeds up host→GPU copies; on a CPU-only run
# PyTorch warns about pin_memory=True and ignores it.
use_pin_memory = device.type == "cuda"

# Settings shared by all three loaders.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=use_pin_memory,
)

# Training batches come from the class-balancing sampler
# (a sampler and shuffle=True are mutually exclusive in DataLoader).
train_loader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)

# Evaluation loaders keep a fixed order so metrics are comparable between runs.
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Hyper-parameters passed to train_model for this run.
learning_rate = 1e-4
weight_decay = 1e-5
early_stop_patience = 2
checkpoint_path = 'best_slice_clf_model.pth'

model = CTClassifier(embed_dim=128).to(device)

# NOTE(review): judging by the argument names, train_model presumably stops
# early after `patience` epochs without validation improvement and writes the
# best weights to `save_path` — confirm in src/train.py.
train_history = train_model(
    model, train_loader, val_loader, device, epochs,
    lr=learning_rate,
    wd=weight_decay,
    patience=early_stop_patience,
    save_path=checkpoint_path,
)

plot_training_history(train_history)

In [None]:
# Final evaluation on the held-out test split.
# NOTE(review): `model` holds whatever weights it had when train_model returned.
# Unless train_model restores the best checkpoint itself, these are the last-epoch
# weights, not the ones saved to 'best_slice_clf_model.pth' — confirm in
# src/train.py and reload the checkpoint here if needed.
validate_model(model, test_loader, device)