Avant, on initialisait les nodes features à 0 (avec la fonction add_zeros), ce qui rends les noeuds indifférentiables. Afin de leur donner de l'information, on modifie tout ça, en permettant de mettre des features aux nodes (par exemple leur degré).

# Change the GNN's

In [1]:
!pip install torch_geometric torch --quiet
import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.utils import degree

import math

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''

        super(GINConv, self).__init__(aggr = "add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.edge_encoder = torch.nn.Linear(7, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.edge_encoder(edge_attr)
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))

        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    def __init__(self, emb_dim):
        super(GCNConv, self).__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.edge_encoder = torch.nn.Linear(7, emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.edge_encoder(edge_attr)

        row, col = edge_index

        #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
        deg = degree(row, x.size(0), dtype = x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out


### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, in_features=1, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers

        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.node_encoder = torch.nn.Linear(in_features, emb_dim) # uniform input node embedding

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch


        ### computing input node embedding

        h_list = [self.node_encoder(x)]  # x shape [num_nodes, in_features]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation


### Virtual GNN to generate node embedding
class GNN_node_Virtualnode(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, in_features=1, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
        '''

        super(GNN_node_Virtualnode, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.node_encoder = torch.nn.Linear(in_features, emb_dim) # uniform input node embedding

        ### set the initial virtual node embedding to 0.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        ### batch norms applied to node embeddings
        self.batch_norms = torch.nn.ModuleList()

        ### List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = torch.nn.ModuleList()

        for layer in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

        for layer in range(num_layer - 1):
            self.mlp_virtualnode_list.append(torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), \
                                                    torch.nn.Linear(2*emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU()))


    def forward(self, batched_data):

        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))

        h_list = [self.node_encoder(x)]  # x shape [num_nodes, in_features]

        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)

            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)
                else:
                    virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.drop_ratio, training = self.training)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation


In [2]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform

class GNN(torch.nn.Module):

    def __init__(self, num_class, num_layer = 5, emb_dim = 300, 
                    gnn_type = 'gin', virtual_node = True, residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        '''

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_class = num_class
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        if virtual_node:
            self.gnn_node = GNN_node_Virtualnode(num_layer, emb_dim, in_features=1, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)

        else:
            self.gnn_node = GNN_node(num_layer, emb_dim, in_features=1, JK=JK, drop_ratio=drop_ratio, residual=residual, gnn_type=gnn_type)


        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2*self.emb_dim, self.num_class)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_class)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)

        return self.graph_pred_linear(h_graph)

In [3]:
%cd hackaton/

/home/onyxia/work/DL-Hackathon/hackaton


# Check the changements

In [4]:
from torch_geometric.utils import degree

def add_degree_feature(data):
    row, col = data.edge_index
    deg = degree(row, data.num_nodes, dtype=torch.float)
    data.x = deg.view(-1, 1)  # tensor shape [num_nodes, 1]
    return data

In [5]:
from src.loadData import GraphDataset
    
full_dataset = GraphDataset('datasets/A/train.json.gz', transform=add_degree_feature)

In [6]:
full_dataset

GraphDataset(11280)

In [7]:
full_dataset[0]

Data(edge_index=[2, 3544], edge_attr=[3544, 7], y=[1], num_nodes=300, x=[300, 1])

In [8]:
full_dataset[0].x

tensor([[ 6.],
        [ 6.],
        [ 5.],
        [ 7.],
        [ 6.],
        [ 9.],
        [ 5.],
        [25.],
        [60.],
        [ 7.],
        [ 1.],
        [ 8.],
        [10.],
        [18.],
        [19.],
        [ 8.],
        [ 7.],
        [ 7.],
        [28.],
        [16.],
        [23.],
        [ 4.],
        [23.],
        [33.],
        [21.],
        [46.],
        [14.],
        [12.],
        [23.],
        [ 5.],
        [16.],
        [17.],
        [ 6.],
        [ 9.],
        [18.],
        [62.],
        [ 6.],
        [35.],
        [16.],
        [12.],
        [13.],
        [76.],
        [20.],
        [40.],
        [10.],
        [48.],
        [38.],
        [19.],
        [40.],
        [22.],
        [12.],
        [ 5.],
        [18.],
        [18.],
        [10.],
        [19.],
        [ 9.],
        [14.],
        [22.],
        [ 8.],
        [17.],
        [44.],
        [10.],
        [10.],
        [39.],
        [27.],
        [4

# Try the changement on co teaching

Maintenant, on essaie co teaching avec ces noeuds qui n'ont plus des 0 comme feature, mais leur degré ! 

In [9]:
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
# Load utility functions from cloned repository
from src.loadData import GraphDataset
from src.utils import set_seed
#from src.models import GNN (We do not need it, as we changed it in the above cells)
import argparse

# Set the random seed
set_seed(42)


In [10]:
def add_degree_feature(data):
    row, col = data.edge_index
    deg = degree(row, data.num_nodes, dtype=torch.float)
    data.x = deg.view(-1, 1)  # tensor shape [num_nodes, 1]
    return data

In [11]:
def train(data_loader, model, optimizer, criterion, device, save_checkpoints, checkpoint_path, current_epoch):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data in tqdm(data_loader, desc="Iterating training graphs", unit="batch"):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total += data.y.size(0)

    # Save checkpoints if required
    if save_checkpoints:
        checkpoint_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved at {checkpoint_file}")

    return total_loss / len(data_loader),  correct / total

In [12]:
def evaluate(data_loader, model, device, calculate_accuracy=False):
    model.eval()
    correct = 0
    total = 0
    predictions = []
    total_loss = 0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for data in tqdm(data_loader, desc="Iterating eval graphs", unit="batch"):
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            
            if calculate_accuracy:
                correct += (pred == data.y).sum().item()
                total += data.y.size(0)
                total_loss += criterion(output, data.y).item()
            else:
                predictions.extend(pred.cpu().numpy())
    if calculate_accuracy:
        accuracy = correct / total
        return  total_loss / len(data_loader),accuracy
    return predictions

In [13]:
def save_predictions(predictions, test_path):
    script_dir = os.getcwd() 
    submission_folder = os.path.join(script_dir, "submission")
    test_dir_name = os.path.basename(os.path.dirname(test_path))
    
    os.makedirs(submission_folder, exist_ok=True)
    
    output_csv_path = os.path.join(submission_folder, f"testset_{test_dir_name}.csv")
    
    test_graph_ids = list(range(len(predictions)))
    output_df = pd.DataFrame({
        "id": test_graph_ids,
        "pred": predictions
    })
    
    output_df.to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")

In [14]:
def plot_training_progress(train_losses, train_accuracies, output_dir):
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(12, 6))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label="Training Loss", color='blue')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss per Epoch')

    # Plot accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label="Training Accuracy", color='green')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training Accuracy per Epoch')

    # Save plots in the current directory
    os.makedirs(output_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_progress.png"))
    plt.close()

In [15]:
def get_user_input(prompt, default=None, required=False, type_cast=str):

    while True:
        user_input = input(f"{prompt} [{default}]: ")
        
        if user_input == "" and required:
            print("This field is required. Please enter a value.")
            continue
        
        if user_input == "" and default is not None:
            return default
        
        if user_input == "" and not required:
            return None
        
        try:
            return type_cast(user_input)
        except ValueError:
            print(f"Invalid input. Please enter a valid {type_cast.__name__}.")

In [16]:
def get_arguments():
    args = {}
    args['train_path'] = get_user_input("Path to the training dataset (optional)")
    args['test_path'] = get_user_input("Path to the test dataset", required=True)
    args['num_checkpoints'] = get_user_input("Number of checkpoints to save during training", type_cast=int)
    args['device'] = get_user_input("Which GPU to use if any", default=1, type_cast=int)
    args['gnn'] = get_user_input("GNN type (gin, gin-virtual, gcn, gcn-virtual)", default='gin')
    args['res'] = get_user_input("Residuals in GNN ? (1 yes, 0 no)", default=0)
    args['drop_ratio'] = get_user_input("Dropout ratio", default=0.0, type_cast=float)
    args['num_layer'] = get_user_input("Number of GNN message passing layers", default=5, type_cast=int)
    args['emb_dim'] = get_user_input("Dimensionality of hidden units in GNNs", default=300, type_cast=int)
    args['batch_size'] = get_user_input("Input batch size for training", default=32, type_cast=int)
    args['epochs'] = get_user_input("Number of epochs to train", default=10, type_cast=int)
    args['baseline_mode'] = get_user_input("Baseline mode: 1 (CE), 2 (Noisy CE)", default=1, type_cast=int)
    args['noise_prob'] = get_user_input("Noise probability p (used if baseline_mode=2)", default=0.2, type_cast=float)
    args['pooling'] = get_user_input("type of pooling (sum, mean, max, attention, set2set)", default='mean')
    
    return argparse.Namespace(**args)


In [17]:
def populate_args(args):
    print("Arguments received:")
    for key, value in vars(args).items():
        print(f"{key}: {value}")

Path to the training dataset (optional) [None]:  datasets/A/train.json.gz
Path to the test dataset [None]:  datasets/A/test.json.gz
Number of checkpoints to save during training [None]:  10
Which GPU to use if any [1]:  1
GNN type (gin, gin-virtual, gcn, gcn-virtual) [gin]:  gcn
Residuals in GNN ? (1 yes, 0 no) [0]:  1
Dropout ratio [0.0]:  0.2
Number of GNN message passing layers [5]:  
Dimensionality of hidden units in GNNs [300]:  
Input batch size for training [32]:  
Number of epochs to train [10]:  
Baseline mode: 1 (CE), 2 (Noisy CE) [1]:  1
Noise probability p (used if baseline_mode=2) [0.2]:  
type of pooling (sum, mean, max, attention, set2set) [mean]:  


Arguments received:
train_path: datasets/A/train.json.gz
test_path: datasets/A/test.json.gz
num_checkpoints: 10
device: 1
gnn: gcn
res: 1
drop_ratio: 0.2
num_layer: 5
emb_dim: 300
batch_size: 32
epochs: 10
baseline_mode: 1
noise_prob: 0.2
pooling: mean


In [None]:
args = get_arguments()
populate_args(args)

In [18]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    def __init__(self, p_noisy):
        super().__init__()
        self.p = p_noisy
        self.ce = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets):
        losses = self.ce(logits, targets)
        weights = (1 - self.p) + self.p * (1 - torch.nn.functional.one_hot(targets, num_classes=logits.size(1)).float().sum(dim=1))
        return (losses * weights).mean()

In [19]:
script_dir = os.getcwd() 
# device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_checkpoints = args.num_checkpoints if args.num_checkpoints else 3
    
if args.gnn == 'gin':
    model1 = GNN(gnn_type='gin', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=False, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
    model2 = GNN(gnn_type='gin', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=False, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
elif args.gnn == 'gin-virtual':
    model1 = GNN(gnn_type='gin', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=True, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
    model2 = GNN(gnn_type='gin', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=True, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
elif args.gnn == 'gcn':
    model1 = GNN(gnn_type='gcn', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=False, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
    model2 = GNN(gnn_type='gcn', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=False, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
elif args.gnn == 'gcn-virtual':
    model1 = GNN(gnn_type='gcn', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=True, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
    model2 = GNN(gnn_type='gcn', num_class=6, num_layer=args.num_layer, emb_dim=args.emb_dim, drop_ratio=args.drop_ratio, virtual_node=True, residual = True if args.res == 1 else False, graph_pooling=args.pooling).to(device)
else:
    raise ValueError('Invalid GNN type')
    
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001)
# criterion = torch.nn.CrossEntropyLoss()
if args.baseline_mode == 2:
    criterion = NoisyCrossEntropyLoss(args.noise_prob)
else:
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

In [20]:
test_dir_name = os.path.basename(os.path.dirname(args.test_path))
logs_folder = os.path.join(script_dir, "logs", test_dir_name)
log_file = os.path.join(logs_folder, "training.log")
os.makedirs(os.path.dirname(log_file), exist_ok=True)
logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(message)s')
logging.getLogger().addHandler(logging.StreamHandler())

checkpoint_path = os.path.join(script_dir, "checkpoints", f"model_{test_dir_name}_best.pth")
checkpoints_folder = os.path.join(script_dir, "checkpoints", test_dir_name)
os.makedirs(checkpoints_folder, exist_ok=True)


In [21]:
if os.path.exists(checkpoint_path) and not args.train_path:
    model1.load_state_dict(torch.load(checkpoint_path))
    print(f"Loaded best model from {checkpoint_path}")

In [22]:
def train_coteaching(train_loader, model1, model2, optimizer1, optimizer2, criterion, device, forget_rate, epoch, num_epochs):
    model1.train()
    model2.train()

    total_correct1 = total_correct2 = total_samples = 0
    running_loss1 = 0.0
    running_loss2 = 0.0

    for batch in train_loader:
        batch = batch.to(device, non_blocking=True)

        out1 = model1(batch)
        out2 = model2(batch)

        # Perte par échantillon (shape: [batch_size])
        loss1 = criterion(out1, batch.y)
        loss2 = criterion(out2, batch.y)

        remember_rate = 1.0 - forget_rate
        num_remember = int(remember_rate * batch.y.size(0))

        _, idx1 = loss1.topk(num_remember, largest=False)
        _, idx2 = loss2.topk(num_remember, largest=False)

        loss1_update = criterion(out1[idx2], batch.y[idx2]).mean()
        loss2_update = criterion(out2[idx1], batch.y[idx1]).mean()

        optimizer1.zero_grad()
        loss1_update.backward()
        optimizer1.step()

        optimizer2.zero_grad()
        loss2_update.backward()
        optimizer2.step()

        # On calcule la loss moyenne pondérée par la taille du batch (pour l’epoch)
        running_loss1 += loss1_update.item() * batch.y.size(0)
        running_loss2 += loss2_update.item() * batch.y.size(0)

        # Calcul des bonnes prédictions
        total_correct1 += (out1.argmax(dim=1) == batch.y).sum().item()
        total_correct2 += (out2.argmax(dim=1) == batch.y).sum().item()
        total_samples += batch.y.size(0)

    avg_loss1 = running_loss1 / total_samples
    avg_loss2 = running_loss2 / total_samples
    acc1 = total_correct1 / total_samples
    acc2 = total_correct2 / total_samples

    return acc1, acc2, avg_loss1, avg_loss2

In [23]:
if args.train_path:
    full_dataset = GraphDataset(args.train_path, transform=add_degree_feature)
    val_size = int(0.2 * len(full_dataset))
    train_size = len(full_dataset) - val_size

    generator = torch.Generator().manual_seed(12)
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    num_epochs = args.epochs
    best_val_accuracy = 0.0

    train_losses_model1 = []
    train_losses_model2 = []
    train_accuracies1 = []
    train_accuracies2 = []
    val_losses_model1 = []
    val_losses_model2 = []
    val_accuracies1 = []
    val_accuracies2 = []


    if num_checkpoints > 1:
        checkpoint_intervals = [int((i + 1) * num_epochs / num_checkpoints) for i in range(num_checkpoints)]
    else:
        checkpoint_intervals = [num_epochs]

    def get_forget_rate(epoch, total_epochs, max_forget=0.2):
        return min(max_forget, max_forget * epoch / (total_epochs * 0.5))

    for epoch in range(num_epochs):
        forget_rate = get_forget_rate(epoch, num_epochs)

        # Phase d'entraînement
        train_acc1, train_acc2, train_loss1, train_loss2 = train_coteaching(
            train_loader, model1, model2, optimizer1, optimizer2,
            criterion, device, forget_rate, epoch, num_epochs
        )

        # Phase de validation
        val_loss1, val_acc1 = evaluate(val_loader, model1, device, calculate_accuracy=True)
        val_loss2, val_acc2 = evaluate(val_loader, model2, device, calculate_accuracy=True)

        # Stocker les métriques
        train_losses_model1.append(train_loss1)
        train_losses_model2.append(train_loss2)
        train_accuracies1.append(train_acc1)
        train_accuracies2.append(train_acc2)
    
        val_losses_model1.append(val_loss1)
        val_losses_model2.append(val_loss2)
        val_accuracies1.append(val_acc1)
        val_accuracies2.append(val_acc2)
    
        print(f"Epoch {epoch+1}/{num_epochs} | Forget rate: {forget_rate:.3f}")
        print(f"Train Loss Model1: {train_loss1:.4f}, Acc: {train_acc1:.4f}")
        print(f"Train Loss Model2: {train_loss2:.4f}, Acc: {train_acc2:.4f}")
        print(f"Val   Loss Model1: {val_loss1:.4f}, Acc: {val_acc1:.4f}")
        print(f"Val   Loss Model2: {val_loss2:.4f}, Acc: {val_acc2:.4f}")
        
        if val_acc1 > best_val_accuracy:
            best_val_accuracy = val_acc1
            torch.save(model1.state_dict(), os.path.join(checkpoints_folder, f"model_{test_dir_name}"))
            print(f"[Co-Teaching] Best model1 updated and saved at {checkpoints_folder}/model_{test_dir_name}")

    # Assuming you have train_losses1 and train_losses2 collected similarly to train_accuracies
    plot_training_progress(train_losses_model1, train_accuracies1, os.path.join(logs_folder, "plots_model1"))
    plot_training_progress(val_losses_model1, val_accuracies1, os.path.join(logs_folder, "plotsVal_model1"))


Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.50batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:33<00:00,  2.15batch/s]


Epoch 1/10 | Forget rate: 0.000
Train Loss Model1: 1.5859, Acc: 0.3896
Train Loss Model2: 1.5714, Acc: 0.3965
Val   Loss Model1: 1.6297, Acc: 0.4215
Val   Loss Model2: 1.5086, Acc: 0.4521
[Co-Teaching] Best model1 updated and saved at /home/onyxia/work/DL-Hackathon/hackaton/checkpoints/A/model_A


Iterating eval graphs: 100%|██████████| 71/71 [00:19<00:00,  3.71batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:19<00:00,  3.61batch/s]


Epoch 2/10 | Forget rate: 0.040
Train Loss Model1: 1.3249, Acc: 0.4576
Train Loss Model2: 1.3092, Acc: 0.4697
Val   Loss Model1: 1.5796, Acc: 0.4517
Val   Loss Model2: 1.4711, Acc: 0.4827
[Co-Teaching] Best model1 updated and saved at /home/onyxia/work/DL-Hackathon/hackaton/checkpoints/A/model_A


Iterating eval graphs: 100%|██████████| 71/71 [00:21<00:00,  3.35batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:16<00:00,  4.34batch/s]


Epoch 3/10 | Forget rate: 0.080
Train Loss Model1: 1.1789, Acc: 0.4927
Train Loss Model2: 1.1678, Acc: 0.4936
Val   Loss Model1: 1.5059, Acc: 0.5044
Val   Loss Model2: 1.4609, Acc: 0.5315
[Co-Teaching] Best model1 updated and saved at /home/onyxia/work/DL-Hackathon/hackaton/checkpoints/A/model_A


Iterating eval graphs: 100%|██████████| 71/71 [00:27<00:00,  2.54batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:25<00:00,  2.78batch/s]


Epoch 4/10 | Forget rate: 0.120
Train Loss Model1: 1.0566, Acc: 0.5135
Train Loss Model2: 1.0518, Acc: 0.5139
Val   Loss Model1: 1.5534, Acc: 0.5381
Val   Loss Model2: 1.6127, Acc: 0.5106
[Co-Teaching] Best model1 updated and saved at /home/onyxia/work/DL-Hackathon/hackaton/checkpoints/A/model_A


Iterating eval graphs: 100%|██████████| 71/71 [00:24<00:00,  2.91batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:24<00:00,  2.89batch/s]


Epoch 5/10 | Forget rate: 0.160
Train Loss Model1: 0.8772, Acc: 0.5289
Train Loss Model2: 0.8721, Acc: 0.5303
Val   Loss Model1: 1.7662, Acc: 0.5177
Val   Loss Model2: 1.6727, Acc: 0.5505


Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.42batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.48batch/s]


Epoch 6/10 | Forget rate: 0.200
Train Loss Model1: 0.7708, Acc: 0.5455
Train Loss Model2: 0.7764, Acc: 0.5458
Val   Loss Model1: 1.7545, Acc: 0.5199
Val   Loss Model2: 1.8308, Acc: 0.5230


Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.48batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.54batch/s]


Epoch 7/10 | Forget rate: 0.200
Train Loss Model1: 0.7305, Acc: 0.5663
Train Loss Model2: 0.7383, Acc: 0.5602
Val   Loss Model1: 1.7026, Acc: 0.5820
Val   Loss Model2: 1.7613, Acc: 0.5554
[Co-Teaching] Best model1 updated and saved at /home/onyxia/work/DL-Hackathon/hackaton/checkpoints/A/model_A


Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.49batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.50batch/s]


Epoch 8/10 | Forget rate: 0.200
Train Loss Model1: 0.6826, Acc: 0.5752
Train Loss Model2: 0.6909, Acc: 0.5776
Val   Loss Model1: 1.7896, Acc: 0.5878
Val   Loss Model2: 1.9017, Acc: 0.5603
[Co-Teaching] Best model1 updated and saved at /home/onyxia/work/DL-Hackathon/hackaton/checkpoints/A/model_A


Iterating eval graphs: 100%|██████████| 71/71 [00:19<00:00,  3.62batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.46batch/s]


Epoch 9/10 | Forget rate: 0.200
Train Loss Model1: 0.6476, Acc: 0.5902
Train Loss Model2: 0.6629, Acc: 0.5878
Val   Loss Model1: 1.9386, Acc: 0.5811
Val   Loss Model2: 1.7902, Acc: 0.5971


Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.52batch/s]
Iterating eval graphs: 100%|██████████| 71/71 [00:20<00:00,  3.49batch/s]


Epoch 10/10 | Forget rate: 0.200
Train Loss Model1: 0.6268, Acc: 0.5966
Train Loss Model2: 0.6318, Acc: 0.5963
Val   Loss Model1: 1.6504, Acc: 0.5900
Val   Loss Model2: 1.8235, Acc: 0.5842
[Co-Teaching] Best model1 updated and saved at /home/onyxia/work/DL-Hackathon/hackaton/checkpoints/A/model_A
