In [1]:
from google.colab import drive
drive.mount('/content/drive')
import os

# 指定 ZIP 檔案路徑和解壓縮目標資料夾
zip_file_path = "/content/drive/MyDrive/nycu-i-al-i-ml-2024-seq-2-seq-and-attention.zip"
output_folder = "/content/unzipped_files"

# 3. 創建解壓縮資料夾 (如果不存在)
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# 4. 解壓縮 ZIP 檔案
!unzip -o "$zip_file_path" -d "$output_folder"

# 5. 確認解壓縮結果
print(f"ZIP 檔案已解壓縮至：{output_folder}")
print("檔案列表：")
!ls "$output_folder"

Mounted at /content/drive
Archive:  /content/drive/MyDrive/nycu-i-al-i-ml-2024-seq-2-seq-and-attention.zip
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/lexicon.txt  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/sample.csv  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/1.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/10.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/100.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/101.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/102.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/103.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/104.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/105.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/106.wav  
  inflating: /content/unzipped_files/kaldi-taiwanese-asr/test/107.wav  
  inflating: /content/unzipped_file

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import numpy as np
from sklearn.model_selection import train_test_split
!pip install torchaudio torch comet-ml

In [8]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import numpy as np
from sklearn.model_selection import train_test_split

##########################
# 1. Lexicon & TextTransform
##########################

def parse_lexicon(lexicon_path):
    """Build a token -> index map from a whitespace-separated lexicon file.

    The first field of every non-blank line is taken as the token (presumably the
    lexicon entry the transcripts are written in — confirm against lexicon.txt's
    format). Tokens get consecutive indices in first-seen order; duplicates keep
    their first index. Finally ``'<SPACE>'`` is assigned the next index, so it is
    always the largest index in the map.

    Args:
        lexicon_path: path to a UTF-8 text file.

    Returns:
        dict mapping token string -> int index.
    """
    phoneme_map = {}
    with open(lexicon_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank / whitespace-only line (e.g. a trailing newline): nothing to index
                continue
            # len(phoneme_map) is the next free index since indices are consecutive
            phoneme_map.setdefault(parts[0], len(phoneme_map))
    # '<SPACE>' always takes the last (largest) index
    phoneme_map['<SPACE>'] = len(phoneme_map)
    return phoneme_map

class TextTransform:
    """Converts between space-separated token strings and integer id lists.

    The vocabulary comes from ``parse_lexicon``. ``blank_label`` is the largest
    id in the vocabulary; it is skipped when converting ids back to text.
    """
    def __init__(self, lexicon_path):
        self.phoneme_map = parse_lexicon(lexicon_path)
        self.index_map = dict((idx, token) for token, idx in self.phoneme_map.items())
        # The largest id doubles as the blank symbol
        self.blank_label = max(self.index_map)

    def text_to_int(self, text):
        """Map each whitespace-separated token to its id; unknown tokens are dropped."""
        lookup = self.phoneme_map
        return [lookup[token] for token in text.split() if token in lookup]

    def int_to_text(self, labels):
        """Join the tokens for ``labels`` with single spaces, skipping blank and unknown ids."""
        kept = []
        for label in labels:
            if label == self.blank_label or label not in self.index_map:
                continue
            token = self.index_map[label]
            if token != '':
                kept.append(token)
        return ' '.join(kept)


##########################
# 2. Audio Transforms (same as original)
##########################

# Mel-spectrogram settings shared by the training and validation front ends:
# 16 kHz input audio, 128 mel bins.
_MEL_KWARGS = dict(sample_rate=16000, n_mels=128)

# Training front end: mel spectrogram followed by random frequency- and
# time-masking augmentation.
train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(**_MEL_KWARGS),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
    torchaudio.transforms.TimeMasking(time_mask_param=35),
)

# Validation / inference front end: the same mel spectrogram, no augmentation.
valid_audio_transforms = torchaudio.transforms.MelSpectrogram(**_MEL_KWARGS)


##########################
# (NEW) 2A. Custom Dataset
##########################

class TaiwaneseDataset(data.Dataset):
    """Map-style dataset yielding ``(spec, label, input_length, label_length)``.

    Each row of ``dataframe`` needs an ``id`` column. The audio is read from
    ``<data_path>/<folder>/<id>.wav``, where ``folder`` is ``'test'`` when
    ``subset == 'test'`` and ``'train'`` otherwise (validation rows also live in
    the train folder). Rows may carry a ``text`` column holding a space-separated
    transcript; a missing or non-string value is treated as an empty transcript.

    Returned values:
      * ``spec``: the transform output with the leading channel axis squeezed and
        the remaining two axes swapped. For a mel transform on mono audio this is
        (time, n_mels).
      * ``label``: LongTensor of token ids; a dummy ``[0]`` for the test subset.
      * ``input_length``: ``spec.shape[0] // 2``. This assumes the model's stem
        conv halves the time axis (stride 2).
      * ``label_length``: number of ids in ``label``.
    """

    # Sample rate the mel front ends in this notebook are configured for
    TARGET_SAMPLE_RATE = 16000

    def __init__(self, dataframe, data_path, text_transform, audio_transforms=None, subset='train'):
        self.dataframe = dataframe.reset_index(drop=True)
        self.data_path = data_path
        self.text_transform = text_transform
        self.audio_transforms = audio_transforms
        self.subset = subset
        # Fallback MelSpectrogram, built on first use and then reused across items
        self._default_transform = None

    def __len__(self):
        return len(self.dataframe)

    def _load_waveform(self, path):
        """Load ``path`` as a mono waveform at ``TARGET_SAMPLE_RATE``, shape (1, samples)."""
        waveform, sr = torchaudio.load(path)
        if waveform.shape[0] > 1:
            # Downmix multi-channel audio; the spectrogram code below expects one channel
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.TARGET_SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, sr, self.TARGET_SAMPLE_RATE)
        return waveform

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        file_id = str(row["id"])

        utterance = row.get("text", "")
        if not isinstance(utterance, str):
            # Missing transcript (e.g. NaN in the CSV)
            utterance = ""

        folder_name = 'test' if self.subset == 'test' else 'train'
        audio_file_path = os.path.join(self.data_path, folder_name, file_id + ".wav")
        waveform = self._load_waveform(audio_file_path)

        transform = self.audio_transforms
        if transform is None:
            if self._default_transform is None:
                self._default_transform = torchaudio.transforms.MelSpectrogram(
                    sample_rate=self.TARGET_SAMPLE_RATE, n_mels=128)
            transform = self._default_transform
        spec = transform(waveform).squeeze(0).transpose(0, 1)

        # Assumes the model's first conv halves the time axis (stride 2)
        input_length = spec.shape[0] // 2
        if self.subset != 'test':
            # Integer class ids, matching the dtype of the test-branch dummy label
            label = torch.tensor(self.text_transform.text_to_int(utterance), dtype=torch.long)
            label_length = len(label)
        else:
            # No transcript for the test set: emit a one-element dummy label
            label = torch.tensor([0], dtype=torch.long)
            label_length = 1

        return spec, label, input_length, label_length


##########################
# (NEW) 2B. Custom Collate Fn
##########################

def taiwanese_collate_fn(batch):
    """Merge dataset samples into a padded mini-batch.

    Args:
        batch: list of ``(spec, label, input_length, label_length)`` tuples, where
            ``spec`` is (time, n_mels) and ``label`` is a 1-D id tensor.

    Returns:
        ``(specs, labels, input_lengths, label_lengths)``:
          * ``specs``: zero-padded along time, shaped (batch, 1, n_mels, max_time).
          * ``labels``: zero-padded to (batch, max_label_len).
          * ``input_lengths`` / ``label_lengths``: plain Python lists, in batch order.
    """
    spec_list, label_list, input_lengths, label_lengths = (
        [sample[field] for sample in batch] for field in range(4)
    )

    # (batch, max_time, n_mels) -> (batch, 1, n_mels, max_time)
    padded_specs = nn.utils.rnn.pad_sequence(spec_list, batch_first=True)
    padded_specs = padded_specs.unsqueeze(1).transpose(2, 3)

    padded_labels = nn.utils.rnn.pad_sequence(label_list, batch_first=True)

    return padded_specs, padded_labels, input_lengths, label_lengths


##########################
# 3. Model (DeepSpeech2)
# [No changes from your original code, except we keep them here for completeness]
##########################

class CNNLayerNorm(nn.Module):
    """LayerNorm over the feature axis of a 4-D CNN activation.

    Layout in and out: (batch, channel, feature, time). ``nn.LayerNorm`` normalizes
    the trailing axis, so feature and time are swapped around the call.
    """
    def __init__(self, n_feats):
        super(CNNLayerNorm, self).__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        x = x.transpose(2, 3).contiguous()  # (batch, channel, time, feature)
        x = self.layer_norm(x)
        return x.transpose(2, 3).contiguous()  # (batch, channel, feature, time)


class ResidualCNN(nn.Module):
    """Pre-norm residual conv block: two (LayerNorm -> GELU -> Dropout -> Conv2d) stages plus a skip.

    The skip connection adds the input unchanged, so it only lines up when
    ``in_channels == out_channels`` and ``stride == 1`` (as used in this model).
    """
    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super(ResidualCNN, self).__init__()
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel//2)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel//2)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        residual = x  # (batch, channel, feature, time)
        x = self.layer_norm1(x)
        x = F.gelu(x)
        x = self.dropout1(x)
        x = self.cnn1(x)

        x = self.layer_norm2(x)
        x = F.gelu(x)
        x = self.dropout2(x)
        x = self.cnn2(x)
        x += residual
        return x


class BidirectionalGRU(nn.Module):
    """LayerNorm -> GELU -> single-layer bidirectional GRU -> Dropout.

    The output feature size is ``2 * hidden_size`` (forward and backward states
    concatenated). ``batch_first`` is passed straight to ``nn.GRU``. It must match
    the layout of the input, otherwise the recurrence runs over the wrong axis.
    """
    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super(BidirectionalGRU, self).__init__()
        self.BiGRU = nn.GRU(
            input_size=rnn_dim, hidden_size=hidden_size,
            num_layers=1, batch_first=batch_first, bidirectional=True
        )
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.layer_norm(x)
        x = F.gelu(x)
        x, _ = self.BiGRU(x)
        x = self.dropout(x)
        return x


class SpeechRecognitionModel(nn.Module):
    """DeepSpeech2-style acoustic model.

    Pipeline: stem Conv2d (kernel 3, 32 channels) -> ``n_cnn_layers`` residual CNN
    blocks -> Linear to ``rnn_dim`` -> ``n_rnn_layers`` bidirectional GRUs ->
    two-layer classifier.

    Input: (batch, 1, n_feats, time). Output: unnormalized class scores of shape
    (batch, time', n_class), where time' is the time axis after the stem conv.
    No softmax is applied here.
    """
    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        # Feature-axis size after the stem conv (kernel 3, padding 1):
        # floor((n_feats + 2 - 3) / stride) + 1. For stride=2 and even n_feats this
        # equals n_feats // 2; the general form also handles odd n_feats / other strides.
        n_feats = (n_feats - 1) // stride + 1
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats*32, rnn_dim)

        # Every GRU sees (batch, time, feature) tensors, so all of them are batch_first.
        # (Using batch_first only for the first layer made later layers recur over the
        # batch axis, mixing different utterances together.)
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(
                rnn_dim=rnn_dim if i == 0 else rnn_dim*2,
                hidden_size=rnn_dim, dropout=dropout, batch_first=True
            )
            for i in range(n_rnn_layers)
        ])

        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)  # (batch, 32, feature', time')
        x = self.rescnn_layers(x)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, 32*feature', time')
        x = x.transpose(1, 2)  # (batch, time', 32*feature')
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x


##########################
# 4. Training / Validation routines
# [No major changes except that we do not call data_processing inside the loop]
##########################

def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter):
    """Run one training epoch with CTC loss.

    Each batch from ``train_loader`` is ``(spectrograms, labels, input_lengths,
    label_lengths)``, with the two length fields as Python sequences. Spectrograms
    and labels are moved to ``device``; the lengths are converted to CPU
    LongTensors. The model output is expected as (batch, time, n_class); it is
    log-softmaxed and transposed to (time, batch, n_class) for ``criterion``.

    The optimizer, scheduler and ``iter_meter`` are all stepped once per batch.
    Progress is printed every 100 batches and on the final batch.
    """
    model.train()
    data_len = len(train_loader.dataset)
    n_batches = len(train_loader)
    samples_seen = 0
    for batch_idx, (spectrograms, labels, input_lengths, label_lengths) in enumerate(train_loader):
        spectrograms, labels = spectrograms.to(device), labels.to(device)
        # CTC length arguments as int64 tensors on CPU
        input_lengths = torch.tensor(input_lengths, dtype=torch.long)
        label_lengths = torch.tensor(label_lengths, dtype=torch.long)

        optimizer.zero_grad()
        output = model(spectrograms)
        # for CTC: (time, batch, n_class)
        output = F.log_softmax(output, dim=2).transpose(0, 1)

        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()

        optimizer.step()
        scheduler.step()
        iter_meter.step()

        # Count actual samples so a smaller final batch is reported correctly
        samples_seen += spectrograms.size(0)
        if batch_idx % 100 == 0 or batch_idx == n_batches - 1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, samples_seen, data_len,
                100. * (batch_idx + 1) / n_batches, loss.item()))

def validate(model, device, val_loader, criterion, iter_meter, text_transform):
    """Evaluate the model on ``val_loader`` without gradient tracking.

    Returns ``(avg_loss, avg_wer, avg_lev_dist)``:
      * ``avg_loss``: the mean of the per-batch CTC losses.
      * ``avg_wer``: the mean per-utterance WER.
      * ``avg_lev_dist``: the mean per-utterance character-level Levenshtein
        distance between greedy-decoded predictions and references.

    The two per-utterance means are 0 when nothing was decoded. ``iter_meter``
    is accepted for signature symmetry with ``train`` but is not used.
    """
    model.eval()
    running_loss = 0
    wer_scores = []
    edit_distances = []
    with torch.no_grad():
        for spectrograms, labels, input_lengths, label_lengths in val_loader:
            spectrograms = spectrograms.to(device)
            labels = labels.to(device)
            input_lengths = torch.tensor(input_lengths, dtype=torch.long)
            label_lengths = torch.tensor(label_lengths, dtype=torch.long)

            # (time, batch, n_class) log-probabilities, the layout CTC expects
            log_probs = F.log_softmax(model(spectrograms), dim=2).transpose(0, 1)
            running_loss += criterion(log_probs, labels, input_lengths, label_lengths).item()

            decoded_preds, decoded_targets = GreedyDecoder(
                log_probs.transpose(0, 1), labels, label_lengths,
                text_transform=text_transform, blank_label=text_transform.blank_label
            )
            for target, pred in zip(decoded_targets, decoded_preds):
                wer_scores.append(wer(target, pred))
                edit_distances.append(levenshtein_distance(target, pred))

    avg_val_loss = running_loss / len(val_loader)
    avg_val_wer = sum(wer_scores) / len(wer_scores) if wer_scores else 0
    avg_val_lev_dist = sum(edit_distances) / len(edit_distances) if edit_distances else 0

    print('Validation set: Average loss: {:.4f}, Average WER: {:.4f}, Average Levenshtein Distance: {:.4f}\n'.format(
        avg_val_loss, avg_val_wer, avg_val_lev_dist))

    return avg_val_loss, avg_val_wer, avg_val_lev_dist


##########################
# 4A. Evaluation Metrics (unchanged from your original, just consolidated)
##########################

def _levenshtein_distance(ref, hyp):
    """[No changes, same as original internal function]"""
    m = len(ref)
    n = len(hyp)
    if ref == hyp:
        return 0
    if m == 0:
        return n
    if n == 0:
        return m
    if m < n:
        ref, hyp = hyp, ref
        m, n = n, m
    distance = np.zeros((2, n + 1), dtype=np.int32)
    for j in range(0, n + 1):
        distance[0][j] = j
    for i in range(1, m + 1):
        prev_row_idx = (i - 1) % 2
        cur_row_idx = i % 2
        distance[cur_row_idx][0] = i
        for j in range(1, n + 1):
            if ref[i - 1] == hyp[j - 1]:
                distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
            else:
                s_num = distance[prev_row_idx][j - 1] + 1
                i_num = distance[cur_row_idx][j - 1] + 1
                d_num = distance[prev_row_idx][j] + 1
                distance[cur_row_idx][j] = min(s_num, i_num, d_num)
    return distance[m % 2][n]

def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Word-level edit distance between two sentences.

    Both strings are split on ``delimiter`` (lower-cased first when
    ``ignore_case``). Returns ``(edit_distance as float, number of reference words)``.
    """
    ref_text, hyp_text = (
        (reference.lower(), hypothesis.lower()) if ignore_case else (reference, hypothesis)
    )
    ref_words = ref_text.split(delimiter)
    distance = _levenshtein_distance(ref_words, hyp_text.split(delimiter))
    return float(distance), len(ref_words)

def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
    """Character-level edit distance between two sentences.

    Runs of spaces are collapsed to one space, or removed entirely when
    ``remove_space``. Text is lower-cased first when ``ignore_case``.
    Returns ``(edit_distance as float, length of the normalized reference)``.
    """
    if ignore_case:
        reference, hypothesis = reference.lower(), hypothesis.lower()
    joiner = '' if remove_space else ' '

    def _normalize(text):
        return joiner.join(piece for piece in text.split(' ') if piece)

    ref_norm = _normalize(reference)
    distance = _levenshtein_distance(ref_norm, _normalize(hypothesis))
    return float(distance), len(ref_norm)

def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Word error rate: word-level edit distance divided by the reference word count.

    Raises:
        ValueError: if the reference has no words.
    """
    errors, n_ref_words = word_errors(reference, hypothesis, ignore_case, delimiter)
    if not n_ref_words:
        raise ValueError("Reference's word number should be greater than 0.")
    return float(errors) / n_ref_words

def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Character error rate: char-level edit distance divided by the reference length.

    Raises:
        ValueError: if the normalized reference is empty.
    """
    errors, n_ref_chars = char_errors(reference, hypothesis, ignore_case, remove_space)
    if not n_ref_chars:
        raise ValueError("Length of reference should be greater than 0.")
    return float(errors) / n_ref_chars

def levenshtein_distance(s1, s2):
    """Unit-cost edit distance between two sequences (characters, for strings).

    The longer sequence drives the outer loop; a single rolling list the width
    of the shorter sequence holds the previous DP row.
    """
    longer, shorter = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    if not shorter:
        return len(longer)

    prev = list(range(len(shorter) + 1))
    for row_no, a in enumerate(longer, start=1):
        cur = [row_no]
        for col, b in enumerate(shorter, start=1):
            cur.append(min(prev[col] + 1,               # insertion
                           cur[col - 1] + 1,            # deletion
                           prev[col - 1] + (a != b)))   # substitution / match
        prev = cur
    return prev[-1]


##########################
# 4B. GreedyDecoder
##########################

def GreedyDecoder(output, labels, label_lengths, text_transform, blank_label=28, collapse_repeated=True):
    """Greedy (best-path) CTC decoding for a batch of model outputs.

    Args:
        output: scores of shape (batch, time, n_class); the arg-max class is taken per frame.
        labels: padded reference ids of shape (batch, max_label_len), or None (e.g. at test time).
        label_lengths: true length of each row of ``labels``, or None.
        text_transform: object whose ``int_to_text`` turns a list of ids into a string.
        blank_label: id of the CTC blank; blank frames are dropped.
        collapse_repeated: when True, a frame whose id equals the immediately
            preceding frame's id is dropped. Repeats are therefore merged unless a
            blank separates them.

    Returns:
        ``(decodes, targets)``: decoded prediction strings, and decoded reference
        strings. ``targets`` is an empty list when ``labels`` or ``label_lengths`` is None.
    """
    # Pull the arg-max ids to host Python ints once. Comparing tensor elements one
    # at a time costs a separate tensor op (and a device sync on GPU) per frame.
    best_paths = torch.argmax(output, dim=2).cpu().tolist()

    decodes = [
        text_transform.int_to_text(_collapse_ctc_path(path, blank_label, collapse_repeated))
        for path in best_paths
    ]

    targets = []
    if labels is not None and label_lengths is not None:
        for i in range(len(best_paths)):
            targets.append(text_transform.int_to_text(labels[i][:label_lengths[i]].tolist()))

    return decodes, targets


def _collapse_ctc_path(path, blank_label, collapse_repeated):
    """Drop blanks (and, optionally, frames repeating the previous frame) from one best path."""
    kept = []
    for j, idx in enumerate(path):
        if idx == blank_label:
            continue
        if collapse_repeated and j != 0 and idx == path[j - 1]:
            continue
        kept.append(idx)
    return kept


##########################
# 5. Inference (No major changes)
##########################

def infer(model, device, audio_path, text_transform):
    """Transcribe one audio file with greedy CTC decoding.

    The audio is downmixed to mono and resampled to 16 kHz (the rate
    ``valid_audio_transforms`` is configured for) before computing the
    spectrogram.

    Returns:
        The decoded transcript string.
    """
    model.eval()
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.shape[0] > 1:
        # Downmix multi-channel audio to a single channel
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)

    spec = valid_audio_transforms(waveform)  # (1, n_mels, time) for mono input
    spec = spec.unsqueeze(0).to(device)      # (1, 1, n_mels, time)
    with torch.no_grad():
        # (1, time, n_class): GreedyDecoder expects batch-major scores
        log_probs = F.log_softmax(model(spec), dim=2)
        decoded_preds, _ = GreedyDecoder(log_probs, None, None, text_transform,
                                         blank_label=text_transform.blank_label)
    return decoded_preds[0]


##########################
# 6. IterMeter, get_device
##########################

class IterMeter(object):
    """Running counter of optimization steps; the count is exposed as ``val``."""

    def __init__(self):
        self.val = 0

    def step(self):
        """Advance the counter by one."""
        self.val = self.val + 1

    def get(self):
        """Return the number of steps taken so far."""
        return self.val

def get_device(use_tpu=False):
    """Pick the compute device and print which one was chosen.

    If ``use_tpu`` is set and ``torch_xla`` imports successfully, the XLA device is
    returned. Otherwise the function falls back to CUDA when available, and to
    the CPU if not.
    """
    if use_tpu:
        try:
            import torch_xla
            import torch_xla.core.xla_model as xm
            tpu_device = xm.xla_device()
            print(f"Using TPU device: {tpu_device}")
            return tpu_device
        except ImportError:
            print("TPU not found. Using GPU or CPU instead.")

    if not torch.cuda.is_available():
        device = torch.device("cpu")
        print(f"Using CPU device: {device}")
        return device

    device = torch.device("cuda")
    print(f"Using GPU device: {device}")
    return device


##########################
# 7. Main
##########################

def main():
    """Train the model, validate it, and write ``submission.csv`` for the test folder.

    Steps:
      1. Build the lexicon-based text transform.
      2. Split ``train-toneless.csv`` 80/10/10 into train / valid / held-out.
      3. Train for ``epochs`` epochs, validating after each one.
      4. Greedily transcribe every ``.wav`` in the ``test`` folder.
    """
    # Hyperparameters
    learning_rate = 5e-4
    batch_size = 2  # start small on memory-limited environment
    epochs = 1      # may adjust for real training
    seed = 7        # seeds torch RNGs and the data split
    data_path = "/content/unzipped_files/kaldi-taiwanese-asr"
    lexicon_path = os.path.join(data_path, "lexicon.txt")
    use_tpu = False

    # Seed torch (CPU and CUDA) so weight init, shuffling and the random masking
    # augmentations are reproducible across runs.
    torch.manual_seed(seed)

    device = get_device(use_tpu)
    # Pinned host memory only helps host-to-CUDA copies
    pin_memory = device.type == 'cuda'

    # Build the text transform with the lexicon
    text_transform = TextTransform(lexicon_path)

    # n_class counts every lexicon entry plus '<SPACE>'; TextTransform uses the
    # largest index as the CTC blank.
    n_class = len(text_transform.phoneme_map)
    print("Total unique phonemes in lexicon:", n_class)
    print("Blank label index:", text_transform.blank_label)

    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": n_class,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs
    }

    # Read the labelled CSV
    csv_path = os.path.join(data_path, 'train-toneless.csv')
    data_df = pd.read_csv(csv_path)

    # 80% train / 10% valid / 10% held out (the held-out split is currently unused)
    train_df, temp_df = train_test_split(data_df, test_size=0.2, random_state=seed)
    valid_df, heldout_df = train_test_split(temp_df, test_size=0.5, random_state=seed)

    train_dataset = TaiwaneseDataset(
        train_df, data_path, text_transform,
        audio_transforms=train_audio_transforms,
        subset='train'
    )
    # Validation rows come from the same CSV, so their audio is in the 'train' folder too
    valid_dataset = TaiwaneseDataset(
        valid_df, data_path, text_transform,
        audio_transforms=valid_audio_transforms,
        subset='train'
    )

    train_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=hparams['batch_size'],
        shuffle=True,
        collate_fn=taiwanese_collate_fn,
        pin_memory=pin_memory
    )
    valid_loader = data.DataLoader(
        dataset=valid_dataset,
        batch_size=hparams['batch_size'],
        shuffle=False,
        collate_fn=taiwanese_collate_fn,
        pin_memory=pin_memory
    )

    model = SpeechRecognitionModel(
        n_cnn_layers=hparams['n_cnn_layers'],
        n_rnn_layers=hparams['n_rnn_layers'],
        rnn_dim=hparams['rnn_dim'],
        n_class=hparams['n_class'],
        n_feats=hparams['n_feats'],
        stride=hparams['stride'],
        dropout=hparams['dropout']
    ).to(device)

    print(model)
    print('Num Model Parameters', sum(param.nelement() for param in model.parameters()))

    optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
    # zero_infinity: when an utterance's input length is too short to align its
    # label, CTC loss is infinite. This zeroes such losses and their gradients
    # instead of letting inf/NaN reach the optimizer.
    criterion = nn.CTCLoss(blank=text_transform.blank_label, zero_infinity=True).to(device)

    # OneCycleLR is stepped once per batch inside train()
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=hparams['learning_rate'],
        steps_per_epoch=len(train_loader),
        epochs=hparams['epochs'],
        anneal_strategy='linear'
    )

    iter_meter = IterMeter()

    for epoch in range(1, hparams['epochs'] + 1):
        train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter)
        validate(model, device, valid_loader, criterion, iter_meter, text_transform)

    # Generate predictions for the Kaggle test folder (no transcripts available)
    test_folder = os.path.join(data_path, 'test')

    def _wav_sort_key(name):
        # Numeric ids sort numerically (1, 2, ..., 10); any other name sorts after them
        stem = os.path.splitext(name)[0]
        return (0, int(stem), stem) if stem.isdigit() else (1, 0, stem)

    # os.listdir order is arbitrary; sort so the submission is deterministic
    test_files = sorted(
        (f for f in os.listdir(test_folder) if f.endswith('.wav')),
        key=_wav_sort_key,
    )

    model.eval()
    predictions = []
    for test_file in test_files:
        audio_path = os.path.join(test_folder, test_file)
        predictions.append({
            'id': os.path.splitext(test_file)[0],
            'text': infer(model, device, audio_path, text_transform),
        })

    # Explicit columns keep the header even if no test files were found
    results_df = pd.DataFrame(predictions, columns=['id', 'text'])
    results_df.to_csv('submission.csv', index=False)
    print(f"Wrote {len(results_df)} predictions to submission.csv")


if __name__ == "__main__":
    main()




Using GPU device: cuda
Total unique phonemes in lexicon: 805
Blank label index: 804
SpeechRecognitionModel(
  (cnn): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (rescnn_layers): Sequential(
    (0): ResidualCNN(
      (cnn1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (cnn2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
      (layer_norm1): CNNLayerNorm(
        (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      )
      (layer_norm2): CNNLayerNorm(
        (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): ResidualCNN(
      (cnn1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (cnn2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inpl