In [1]:
from datasets import load_dataset

from tqdm import tqdm  # Importar tqdm para la barra de progreso

import numpy as np

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import torch.nn as nn

from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ds = load_dataset("martingrzzler/kanjis2radicals")
""" Item in ds['train']
 Dataset({
    features: ['kanji_image', 'meta'],
    num_rows: 2027
})

{'kanji_image': Image(mode=None, decode=True, id=None), 
 'meta': {'id': Value(dtype='int32', id=None), 
          'characters': Value(dtype='string', id=None), 
          'meanings': Value(dtype='string', id=None), 
          'radicals': Sequence(feature={'characters': Value(dtype='string', id=None), 
                                        'id': Value(dtype='int32', id=None), 
                                        'slug': Value(dtype='string', id=None)
                                        }, length=-1, id=None)
         }
}
"""
# Split the dataset into train and test (80-20 split by default)
ds_split = ds['train'].train_test_split(test_size=0.2)

# Access the train and test subsets
ds_train = ds_split['train']
ds_test = ds_split['test']

In [3]:
# Extract all unique radicals
unique_radicals = set()
for element in ds['train']:
    radical_ids = element['meta']['radicals']['id']
    unique_radicals.update(radical_ids)

# Sort the unique radicals
sorted_unique_radicals = sorted(unique_radicals)

# Count total unique radicals
total_unique_radicals = len(sorted_unique_radicals)

print(f"Total unique radicals: {total_unique_radicals}")
print(f"Sorted unique radicals: {sorted_unique_radicals}")

Total unique radicals: 478
Sorted unique radicals: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215

In [4]:
def label_vector(ds):
    # Crear un mapeo de radicales únicos a índices
    radical_to_index = {radical: idx for idx, radical in enumerate(unique_radicals)}

    # Transformar los identificadores de radicales en vectores binarios
    def generate_label_vector(radical_ids, total_radicals):
        label_vector = [0] * total_radicals
        for radical_id in radical_ids:
            if radical_id in radical_to_index:
                label_vector[radical_to_index[radical_id]] = 1
        return label_vector

    # Generar labels para el dataset de entrenamiento
    total_unique_radicals = len(unique_radicals)
    labels = []

    for element in ds:
        radical_ids = element['meta']['radicals']['id']
        label_vector = generate_label_vector(radical_ids, total_unique_radicals)
        labels.append(label_vector)
    return labels


In [5]:
train_labels = label_vector(ds_train)
test_labels = label_vector(ds_test)

In [6]:
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize images to a fixed size
    transforms.ToTensor(),         # Convert image to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1] for RGB
])

In [7]:
def sample_batch(ds_image, ds_label):
    # Define a transformation pipeline for the images
    batch_images = []
    batch_labels = []

    for i, _ in enumerate(ds_image):
        # Example image path and label (replace with actual paths)
        radical_ids = ds_label[i]
        
        # Load and transform the image
        image = ds_image[i]['kanji_image'].convert("RGB")
        image_tensor = image_transform(image)
        
        # Convert radical IDs to tensor
        label_tensor = torch.tensor(radical_ids, dtype=torch.long)
        
        # Append to batch lists
        batch_images.append(image_tensor)
        batch_labels.append(label_tensor)

    # Stack into a batch (list of tensors to tensor batch)
    batch_images = torch.stack(batch_images)
    batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-1)

    print(f"Batch Images Shape: {batch_images.shape}")  # [batch_size, channels, height, width]
    print(f"Batch Labels Shape: {batch_labels.shape}")  # [batch_size, max_seq_len]
    return batch_images, batch_labels


In [8]:

train_img, train_labels = sample_batch(ds_train, train_labels)
test_img, test_labels = sample_batch(ds_test, test_labels)

Batch Images Shape: torch.Size([1621, 3, 128, 128])
Batch Labels Shape: torch.Size([1621, 478])
Batch Images Shape: torch.Size([406, 3, 128, 128])
Batch Labels Shape: torch.Size([406, 478])


In [9]:
# Define the collate_fn to pad sequences and stack the images
def collate_fn(batch):
    images = torch.stack([item[0] for item in batch])  # Stack all images in the batch
    labels = [item[1] for item in batch]  # Get the labels (radical IDs) for each image
    
    # Pad the label sequences
    labels_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-1)
    
    return images, labels_padded

In [10]:
class KanjiRadicalDataset(Dataset):
    def __init__(self, images, labels):
        """
        Args:
            images (list): List of file paths to the Kanji images.
            labels (list): List of corresponding radical ID lists for each image.
        """
        self.images = images
        self.labels = labels

    def __len__(self):
        # Return the number of samples
        return len(self.images)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample to retrieve.
        
        Returns:
            dict: A dictionary with 'image' and 'label' tensors.
        """
        image = self.images[idx]
        label = self.labels[idx]
        
        return image, label


In [11]:
# Create DataLoader for training and testing
train_dataset = KanjiRadicalDataset(images=train_img, labels=train_labels)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=lambda batch: collate_fn(batch))

test_dataset = KanjiRadicalDataset(images=test_img, labels=test_labels)

test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=lambda batch: collate_fn(batch))


In [12]:
class conv_block_nested(nn.Module):

    def __init__(self, in_ch, mid_ch, out_ch):
        super(conv_block_nested, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        output = self.activation(x)

        return output

In [13]:
class Nested_UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(Nested_UNet, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])

        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])

        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2])

        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1])

        self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0])

        self.final = nn.Sequential(
            nn.Conv2d(filters[0], out_ch, kernel_size=1),
            nn.AdaptiveAvgPool2d((1, 1))
        )

    def forward(self, x):
        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        output = self.final(x0_4)
        output = output.view(output.size(0), -1)  # [batch_size, out_ch]
        return output


In [14]:
# Calcular el peso basado en la proporción de ceros y unos
# Calcular pos_weight basado en el dataset (num_negatives / num_positives)
# Convertir las etiquetas en un array de numpy para facilidad de conteo
labels_array = np.array(train_labels)  # Asegúrate de que train_labels es una lista o array-like

# Contar la cantidad de ceros y unos
num_negatives = np.sum(labels_array == 0)  # Cantidad de ceros
num_positives = np.sum(labels_array == 1)  # Cantidad de unos

pos_weight = torch.tensor([num_negatives / num_positives])

# Model, Loss, Optimizer
num_radicals = 478  # Total unique radicals in the dataset
model = Nested_UNet(3, num_radicals)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # Multi-label classification con balance
optimizer = Adam(model.parameters(), lr=1e-4)


# Training Loop
for epoch in range(2):  # Número de epochs
    model.train()
    epoch_loss = 0  # Pérdida acumulada por epoch

    # Barra de progreso para batches
    with tqdm(train_loader, desc=f"Epoch {epoch+1}") as pbar:
        for batch in pbar:
            images, targets = batch  # Images: [batch_size, 3, 128, 128], Targets: [batch_size, num_radicals]
            optimizer.zero_grad()

            # Forward pass
            predictions = model(images)  # Output: [batch_size, num_radicals]

            # Calculate loss
            loss = criterion(predictions, targets.float())  # Convertir targets a float
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            pbar.set_postfix({"Loss": loss.item()})  # Actualizar barra de progreso con pérdida actual

    print(f"Epoch {epoch+1}, Average Loss: {epoch_loss / len(train_loader)}")


Epoch 1: 100%|██████████| 102/102 [52:46<00:00, 31.05s/it, Loss=1.33]


Epoch 1, Average Loss: 1.3468477936352001


Epoch 2: 100%|██████████| 102/102 [51:37<00:00, 30.37s/it, Loss=1.17]

Epoch 2, Average Loss: 1.2750413417816162





In [15]:
# Evaluación con ajuste de threshold
threshold = 0.7  # Ajustar el umbral para clasificar correctamente los unos
model.eval()
all_predictions = []
all_targets = []

with torch.no_grad():
    for batch in test_loader:  # test_loader contiene los datos de validación
        images, targets = batch
        
        raw_outputs = model(images)  # Predicciones en bruto
        predictions = (torch.sigmoid(raw_outputs) > threshold).float()  # Aplicar umbral
        
        all_predictions.append(predictions.cpu())  # Guardar predicciones
        all_targets.append(targets.cpu())          # Guardar objetivos reales


In [18]:
# Calcular métricas adicionales (F1 Score, Recall)
from sklearn.metrics import f1_score, recall_score

all_predictions = torch.cat(all_predictions).numpy()
all_targets = torch.cat(all_targets).numpy()

f1 = f1_score(all_targets, all_predictions, average="macro")
recall = recall_score(all_targets, all_predictions, average="macro")

print(f"F1 Score: {f1:.4f}, Recall: {recall:.4f}")

F1 Score: 0.0062, Recall: 0.0135


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [19]:
print(all_predictions[0])

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 1. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

In [20]:
print(all_targets[0])


[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
