In [22]:
sents = [
    [
        "the", "cat", "is", "a", "dog"
    ],
    [
        "I", "love", "to", "be", "an", "engineer"
    ],
    [
        "try", "me"
    ]
]
pad_token = '<pad>'
sents_padded = []

In [24]:

sentences = sents.copy()
max_length = max(len(s) for s in sentences)
for s in sentences:
    s.extend([pad_token] * (max_length - len(s)))
    sents_padded.append(s)
    print(s)
print(sents_padded)
    

['the', 'cat', 'is', 'a', 'dog', '<pad>']
['I', 'love', 'to', 'be', 'an', 'engineer']
['try', 'me', '<pad>', '<pad>', '<pad>', '<pad>']
[['the', 'cat', 'is', 'a', 'dog', '<pad>'], ['I', 'love', 'to', 'be', 'an', 'engineer'], ['try', 'me', '<pad>', '<pad>', '<pad>', '<pad>']]


In [3]:
import torch

In [4]:
import torch.nn as nn
import torch.nn.utils
import torch.nn.functional as F

In [6]:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N Spring 2024: Assignment 3
model_embeddings.py: Embeddings for the NMT model
Pencheng Yin <pcyin@cs.cmu.edu>
Sahil Chopra <schopra8@stanford.edu>
Anand Dhoot <anandd@stanford.edu>
Vera Lin <veralin@stanford.edu>
Siyan Li <siyanli@stanford.edu>
"""

import torch.nn as nn

class ModelEmbeddings(nn.Module):
    """
    Class that converts input words to their embeddings.
    """
    def __init__(self, embed_size, vocab):
        """
        Init the Embedding layers.

        @param embed_size (int): Embedding size (dimensionality)
        @param vocab (Vocab): Vocabulary object containing src and tgt languages
                              See vocab.py for documentation.
        """
        super(ModelEmbeddings, self).__init__()
        self.embed_size = embed_size

        # default values
        self.source = None
        self.target = None

        src_pad_token_idx = vocab.src['<pad>']
        tgt_pad_token_idx = vocab.tgt['<pad>']

        ### YOUR CODE HERE (~2 Lines)
        ### TODO - Initialize the following variables:
        ###     self.source (Embedding Layer for source language)
        ###     self.target (Embedding Layer for target langauge)
        ###
        ### Note:
        ###     1. `vocab` object contains two vocabularies:
        ###            `vocab.src` for source
        ###            `vocab.tgt` for target
        ###     2. You can get the length of a specific vocabulary by running:
        ###             `len(vocab.<specific_vocabulary>)`
        ###     3. Remember to include the padding token for the specific vocabulary
        ###        when creating your Embedding.
        ###
        ### Use the following docs to properly initialize these variables:
        ###     Embedding Layer:
        ###         https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html

        self.source = nn.Embedding(len(vocab.src), embed_size, padding_idx=src_pad_token_idx)
        self.target = nn.Embedding(len(vocab.tgt), embed_size, padding_idx=tgt_pad_token_idx)

        ### END YOUR CODE




In [5]:
embed_size = 5


In [7]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N Spring 2024: Assignment 3
vocab.py: Vocabulary Generation
Pencheng Yin <pcyin@cs.cmu.edu>
Sahil Chopra <schopra8@stanford.edu>
Vera Lin <veralin@stanford.edu>
Siyan Li <siyanli@stanford.edu>

Usage:
    vocab.py --train-src=<file> --train-tgt=<file> [options] VOCAB_FILE

Options:
    -h --help                  Show this screen.
    --train-src=<file>         File of training source sentences
    --train-tgt=<file>         File of training target sentences
    --size=<int>               vocab size [default: 50000]
    --freq-cutoff=<int>        frequency cutoff [default: 2]
"""

from collections import Counter
from docopt import docopt
from itertools import chain
import json
import torch
from typing import List
from utils import read_corpus, pad_sents
import sentencepiece as spm


class VocabEntry(object):
    """ Vocabulary Entry, i.e. structure containing either
    src or tgt language terms.
    """
    def __init__(self, word2id=None):
        """ Init VocabEntry Instance.
        @param word2id (dict): dictionary mapping words 2 indices
        """
        if word2id:
            self.word2id = word2id
        else:
            self.word2id = dict()
            self.word2id['<pad>'] = 0   # Pad Token
            self.word2id['<s>'] = 1 # Start Token
            self.word2id['</s>'] = 2    # End Token
            self.word2id['<unk>'] = 3   # Unknown Token
        self.unk_id = self.word2id['<unk>']
        self.id2word = {v: k for k, v in self.word2id.items()}

    def __getitem__(self, word):
        """ Retrieve word's index. Return the index for the unk
        token if the word is out of vocabulary.
        @param word (str): word to look up.
        @returns index (int): index of word 
        """
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        """ Check if word is captured by VocabEntry.
        @param word (str): word to look up
        @returns contains (bool): whether word is contained    
        """
        return word in self.word2id

    def __setitem__(self, key, value):
        """ Raise error, if one tries to edit the VocabEntry.
        """
        raise ValueError('vocabulary is readonly')

    def __len__(self):
        """ Compute number of words in VocabEntry.
        @returns len (int): number of words in VocabEntry
        """
        return len(self.word2id)

    def __repr__(self):
        """ Representation of VocabEntry to be used
        when printing the object.
        """
        return 'Vocabulary[size=%d]' % len(self)

    def id2word(self, wid):
        """ Return mapping of index to word.
        @param wid (int): word index
        @returns word (str): word corresponding to index
        """
        return self.id2word[wid]

    def add(self, word):
        """ Add word to VocabEntry, if it is previously unseen.
        @param word (str): word to add to VocabEntry
        @return index (int): index that the word has been assigned
        """
        if word not in self:
            wid = self.word2id[word] = len(self)
            self.id2word[wid] = word
            return wid
        else:
            return self[word]

    def words2indices(self, sents):
        """ Convert list of words or list of sentences of words
        into list or list of list of indices.
        @param sents (list[str] or list[list[str]]): sentence(s) in words
        @return word_ids (list[int] or list[list[int]]): sentence(s) in indices
        """
        if type(sents[0]) == list:
            return [[self[w] for w in s] for s in sents]
        else:
            return [self[w] for w in sents]

    def indices2words(self, word_ids):
        """ Convert list of indices into words.
        @param word_ids (list[int]): list of word ids
        @return sents (list[str]): list of words
        """
        return [self.id2word[w_id] for w_id in word_ids]

    def to_input_tensor(self, sents: List[List[str]], device: torch.device) -> torch.Tensor:
        """ Convert list of sentences (words) into tensor with necessary padding for 
        shorter sentences.

        @param sents (List[List[str]]): list of sentences (words)
        @param device: device on which to load the tesnor, i.e. CPU or GPU

        @returns sents_var: tensor of (max_sentence_length, batch_size)
        """
        word_ids = self.words2indices(sents)
        sents_t = pad_sents(word_ids, self['<pad>'])
        sents_var = torch.tensor(sents_t, dtype=torch.long, device=device)
        return torch.t(sents_var)

    @staticmethod
    def from_corpus(corpus, size, freq_cutoff=2):
        """ Given a corpus construct a Vocab Entry.
        @param corpus (list[str]): corpus of text produced by read_corpus function
        @param size (int): # of words in vocabulary
        @param freq_cutoff (int): if word occurs n < freq_cutoff times, drop the word
        @returns vocab_entry (VocabEntry): VocabEntry instance produced from provided corpus
        """
        vocab_entry = VocabEntry()
        word_freq = Counter(chain(*corpus))
        valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
        print('number of word types: {}, number of word types w/ frequency >= {}: {}'
              .format(len(word_freq), freq_cutoff, len(valid_words)))
        top_k_words = sorted(valid_words, key=lambda w: word_freq[w], reverse=True)[:size]
        for word in top_k_words:
            vocab_entry.add(word)
        return vocab_entry

    @staticmethod
    def from_subword_list(subword_list):
        vocab_entry = VocabEntry()
        for subword in subword_list:
            vocab_entry.add(subword)
        return vocab_entry


class Vocab(object):
    """ Vocab encapsulating src and target langauges.
    """
    def __init__(self, src_vocab: VocabEntry, tgt_vocab: VocabEntry):
        """ Init Vocab.
        @param src_vocab (VocabEntry): VocabEntry for source language
        @param tgt_vocab (VocabEntry): VocabEntry for target language
        """
        self.src = src_vocab
        self.tgt = tgt_vocab

    @staticmethod
    def build(src_sents, tgt_sents) -> 'Vocab':
        """ Build Vocabulary.
        @param src_sents (list[str]): Source subwords provided by SentencePiece
        @param tgt_sents (list[str]): Target subwords provided by SentencePiece
        """
        # assert len(src_sents) == len(tgt_sents)

        print('initialize source vocabulary ..')
        # src = VocabEntry.from_corpus(src_sents, vocab_size, freq_cutoff)
        src = VocabEntry.from_subword_list(src_sents)

        print('initialize target vocabulary ..')
        # tgt = VocabEntry.from_corpus(tgt_sents, vocab_size, freq_cutoff)
        tgt = VocabEntry.from_subword_list(tgt_sents)

        return Vocab(src, tgt)

    def save(self, file_path):
        """ Save Vocab to file as JSON dump.
        @param file_path (str): file path to vocab file
        """
        with open(file_path, 'w') as f:
            json.dump(dict(src_word2id=self.src.word2id, tgt_word2id=self.tgt.word2id), f, indent=2)

    @staticmethod
    def load(file_path):
        """ Load vocabulary from JSON dump.
        @param file_path (str): file path to vocab file
        @returns Vocab object loaded from JSON dump
        """
        entry = json.load(open(file_path, 'r'))
        src_word2id = entry['src_word2id']
        tgt_word2id = entry['tgt_word2id']

        return Vocab(VocabEntry(src_word2id), VocabEntry(tgt_word2id))

    def __repr__(self):
        """ Representation of Vocab to be used
        when printing the object.
        """
        return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))


def get_vocab_list(file_path, source, vocab_size):
    """ Use SentencePiece to tokenize and acquire list of unique subwords.
    @param file_path (str): file path to corpus
    @param source (str): tgt or src
    @param vocab_size: desired vocabulary size
    """
    spm.SentencePieceTrainer.Train(input=file_path, model_prefix=source, vocab_size=vocab_size)     # train the spm model
    sp = spm.SentencePieceProcessor()   # create an instance; this saves .model and .vocab files 
    sp.Load('{}.model'.format(source))  # loads tgt.model or src.model
    sp_list = [sp.IdToPiece(piece_id) for piece_id in range(sp.GetPieceSize())] # this is the list of subwords
    return sp_list



if __name__ == '__main__':
    args = docopt(__doc__)

    print('read in source sentences: %s' % args['--train-src'])
    print('read in target sentences: %s' % args['--train-tgt'])

    src_sents = get_vocab_list(args['--train-src'], source='src', vocab_size=21000)          # EDIT: NEW VOCAB SIZE
    tgt_sents = get_vocab_list(args['--train-tgt'], source='tgt', vocab_size=8000)
    vocab = Vocab.build(src_sents, tgt_sents)
    print('generated vocabulary, source %d words, target %d words' % (len(src_sents), len(tgt_sents)))

    # src_sents = read_corpus(args['--train-src'], source='src')
    # tgt_sents = read_corpus(args['--train-tgt'], source='tgt')

    # vocab = Vocab.build(src_sents, tgt_sents, int(args['--size']), int(args['--freq-cutoff']))
    # print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))

    vocab.save(args['VOCAB_FILE'])
    print('vocabulary saved to %s' % args['VOCAB_FILE'])


ModuleNotFoundError: No module named 'utils'

In [28]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CS224N 2022-23: Homework 4
utils.py: Utility Functions
Pencheng Yin <pcyin@cs.cmu.edu>
Sahil Chopra <schopra8@stanford.edu>
Vera Lin <veralin@stanford.edu>
Siyan Li <siyanli@stanford.edu>
Moussa KB Doumbouya <moussa@stanford.edu>
"""

from typing import List
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nltk
import sentencepiece as spm
nltk.download('punkt')


def pad_sents(sents, pad_token):
    """ Pad list of sentences according to the longest sentence in the batch.
        The paddings should be at the end of each sentence.
    @param sents (list[list[str]]): list of sentences, where each sentence
                                    is represented as a list of words
    @param pad_token (str): padding token
    @returns sents_padded (list[list[str]]): list of sentences where sentences shorter
        than the max length sentence are padded out with the pad_token, such that
        each sentences in the batch now has equal length.
    """
    sents_padded = []

    ### YOUR CODE HERE (~6 Lines)

    sentences = sents.copy()
    max_length = max(len(s) for s in sentences)
    for s in sentences:
        s.extend([pad_token] * (max_length - len(s)))
        sents_padded.append(s)

    ### END YOUR CODE

    return sents_padded


def read_corpus(file_path, source, vocab_size=2500):
    """ Read file, where each sentence is dilineated by a `\n`.
    @param file_path (str): path to file containing corpus
    @param source (str): "tgt" or "src" indicating whether text
        is of the source language or target language
    @param vocab_size (int): number of unique subwords in
        vocabulary when reading and tokenizing
    """
    data = []
    sp = spm.SentencePieceProcessor()
    sp.load('{}.model'.format(source))

    with open(file_path, 'r', encoding='utf8') as f:
        for line in f:
            subword_tokens = sp.encode_as_pieces(line)
            # only append <s> and </s> to the target sentence
            if source == 'tgt':
                subword_tokens = ['<s>'] + subword_tokens + ['</s>']
            data.append(subword_tokens)

    return data


def autograder_read_corpus(file_path, source):
    """ Read file, where each sentence is dilineated by a `\n`.
    @param file_path (str): path to file containing corpus
    @param source (str): "tgt" or "src" indicating whether text
        is of the source language or target language
    """
    data = []
    for line in open(file_path):
        sent = nltk.word_tokenize(line)
        # only append <s> and </s> to the target sentence
        if source == 'tgt':
            sent = ['<s>'] + sent + ['</s>']
        data.append(sent)

    return data


def batch_iter(data, batch_size, shuffle=False):
    """ Yield batches of source and target sentences reverse sorted by length (largest to smallest).
    @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
    @param batch_size (int): batch size
    @param shuffle (boolean): whether to randomly shuffle the dataset
    """
    batch_num = math.ceil(len(data) / batch_size)
    index_array = list(range(len(data)))

    if shuffle:
        np.random.shuffle(index_array)

    for i in range(batch_num):
        indices = index_array[i * batch_size: (i + 1) * batch_size]
        examples = [data[idx] for idx in indices]

        examples = sorted(examples, key=lambda e: len(e[0]), reverse=True)
        src_sents = [e[0] for e in examples]
        tgt_sents = [e[1] for e in examples]

        yield src_sents, tgt_sents




[nltk_data] Downloading package punkt to /Users/Paul/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [29]:
sents = [['Comencemos', 'por', 'pensar', 'en', 'los', 'pases', 'miembros', 'de', 'la', 'OCDE', ',', 'o', 'la', 'Organizacin', 'para', 'la', 'Cooperacin', 'y', 'el', 'Desarrollo', 'Econmicos', '.'], ['En', 'el', 'caso', 'de', 'control', 'de', 'armas', ',', 'realmente', 'subestimamos', 'a', 'nuestros', 'rivales', '.'], ['Sugiere', 'que', 'nos', 'interesa', 'el', 'combate', ',', 'el', 'desafo', '.'], ['Djenme', 'compartir', 'con', 'ustedes', 'aqu', 'en', 'la', 'primera', 'fila', '.'], ['Con', 'muchos', 'nmeros', '.', 'Un', 'montn']]


In [31]:
len(sents[1])

22

In [30]:
print(pad_sents(sents, '<pad>'))

[['Comencemos', 'por', 'pensar', 'en', 'los', 'pases', 'miembros', 'de', 'la', 'OCDE', ',', 'o', 'la', 'Organizacin', 'para', 'la', 'Cooperacin', 'y', 'el', 'Desarrollo', 'Econmicos', '.'], ['En', 'el', 'caso', 'de', 'control', 'de', 'armas', ',', 'realmente', 'subestimamos', 'a', 'nuestros', 'rivales', '.', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['Sugiere', 'que', 'nos', 'interesa', 'el', 'combate', ',', 'el', 'desafo', '.', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['Djenme', 'compartir', 'con', 'ustedes', 'aqu', 'en', 'la', 'primera', 'fila', '.', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['Con', 'muchos', 'nmeros', '.', 'Un', 'montn', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]


In [32]:
x = torch.randn(2, 3, 5)
x.size()

torch.Size([2, 3, 5])

In [33]:
torch.permute(x, (2, 0, 1)).size()

torch.Size([5, 2, 3])

In [36]:
last_hidden = torch.tensor([[[0.1485, 0.1485],
                      [0.1565, 0.1565],
                      [0.1553, 0.1553],
                      [0.1553, 0.1553],
                      [0.1476, 0.1476]],

                     [[0.1567, 0.1567],
                      [0.1565, 0.1565],
                      [0.1553, 0.1553],
                      [0.1553, 0.1553],
                      [0.1476, 0.1476]]])
# last_hidden.shape: torch.Size([2, 5, 2])
last_cell = torch.tensor([[[0.2753, 0.2753],
                    [0.2882, 0.2882],
                    [0.2860, 0.2860],
                    [0.2860, 0.2860],
                    [0.2714, 0.2714]],

                   [[0.2886, 0.2886],
                    [0.2882, 0.2882],
                    [0.2860, 0.2860],
                    [0.2860, 0.2860],
                    [0.2714, 0.2714]]])


In [40]:
last_hidden[0]

tensor([[0.1485, 0.1485],
        [0.1565, 0.1565],
        [0.1553, 0.1553],
        [0.1553, 0.1553],
        [0.1476, 0.1476]])

In [41]:
last_hidden[1]

tensor([[0.1567, 0.1567],
        [0.1565, 0.1565],
        [0.1553, 0.1553],
        [0.1553, 0.1553],
        [0.1476, 0.1476]])

In [43]:
torch.cat((last_hidden[0], last_hidden[1]), 1)

tensor([[0.1485, 0.1485, 0.1567, 0.1567],
        [0.1565, 0.1565, 0.1565, 0.1565],
        [0.1553, 0.1553, 0.1553, 0.1553],
        [0.1553, 0.1553, 0.1553, 0.1553],
        [0.1476, 0.1476, 0.1476, 0.1476]])

In [44]:
Y = torch.tensor([[[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]],

           [[0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500],
            [0.1500, 0.1500, 0.1500]]])

In [46]:
print(f'Y.shape = {Y.shape}')

Y.shape = torch.Size([23, 5, 3])


In [50]:
for Y_t in torch.split(Y,1):
    Y_t = torch.squeeze(Y_t, dim=0)
    print(f'Y_t.shape = {Y_t.shape}')

Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
Y_t.shape = torch.Size([5, 3])
