In [None]:

def introduce(name, age, city):
    """Print a one-line self-introduction.

    Args:
        name: The person's name.
        age: The person's age (any value; it is formatted with str()).
        city: Where the person lives.
    """
    message = f"Hello, my name is {name}. I am {age} years old and live in {city}."
    print(message)


# The three values introduce() expects, packed into a single list.
person_info = ["Alice", 30, "New York"]


# ---------------------------------------------------------------------
# Scenario A: The WRONG way (without the * operator)
# ---------------------------------------------------------------------
# Passing the list itself supplies only ONE positional argument (`name`),
# so Python raises TypeError because `age` and `city` are missing.
# The error is caught and shown instead of leaving the call commented out,
# so the cell still runs top to bottom.
try:
    introduce(person_info)
except TypeError as exc:
    print(f"TypeError: {exc}")


# ---------------------------------------------------------------------
# Scenario B: The RIGHT way (with the * operator)
# ---------------------------------------------------------------------
# `*person_info` unpacks the list so each element becomes its own
# positional argument: introduce("Alice", 30, "New York").
introduce(*person_info)

Hello, my name is Alice. I am 30 years old and live in New York.


In [None]:
from itertools import chain

# A toy tokenized batch: each key maps to a list of per-example rows whose
# lengths differ (a ragged batch).
sample = {
    "input_ids": [
        [101, 4983, 102],
        [101, 2023, 2003, 102],
        [101, 2759, 2007, 4937, 102],
    ],
    "attention_mask": [
        [1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ],
}

print(sample)
print("\n" + "=" * 50 + "\n")

# For every key, `*rows` unpacks the list of rows into separate arguments
# to `chain`, which yields their elements back to back; list(...) turns
# that into one flat list per key.
concatenated_examples = {key: list(chain(*rows)) for key, rows in sample.items()}

print(concatenated_examples)

{'input_ids': [[101, 4983, 102], [101, 2023, 2003, 102], [101, 2759, 2007, 4937, 102]], 'attention_mask': [[1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1]]}


{'input_ids': [101, 4983, 102, 101, 2023, 2003, 102, 101, 2759, 2007, 4937, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [None]:
# Flatten a ragged list of token-id rows into one flat list: `*` unpacks the
# outer list so each row becomes a separate argument to `chain`.
a = list(chain(*[
    [101, 4983, 102],
    [101, 2023, 2003, 102],
    [101, 2759, 2007, 4937, 102],
]))
a

[101, 4983, 102, 101, 2023, 2003, 102, 101, 2759, 2007, 4937, 102]

In [None]:
from itertools import chain


# Carry-over buffer: tokens at the end of one round that did not fill a
# whole chunk. It starts empty and is prepended to the next round's stream.
remainder = {"input_ids": [], "attention_mask": []}
print("--- INITIAL STATE ---")
print(f"Initial Remainder: {remainder}\n")

# Chunk size used to decide how many tokens are left over after a round.
chunk_length = 10


# ==========================================================
# --- ROUND 1 ---
# ==========================================================
print("="*20 + " ROUND 1 " + "="*20)


batch_1 = {
    "input_ids": [
        [1, 2, 3],
        [4, 5, 6, 7],
        [8, 9, 10, 11, 12]
    ],
    "attention_mask": [
        [1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1, 1]
    ]
}
print("1. Input for Round 1:", batch_1)

# First line: flatten each key's list of rows into one continuous list.
concatenated_examples = {k: list(chain(*rows)) for k, rows in batch_1.items()}
print("\n2. After flattening Batch 1:", concatenated_examples)

# Second line: prepend the (currently empty) remainder
final_stream = {
    k: remainder[k] + tokens for k, tokens in concatenated_examples.items()
}
print("\n3. Final stream after adding remainder:", final_stream)


# --- For demonstration, compute what the remainder WOULD be ---
# Only whole chunks are emitted; everything after the last full chunk is
# carried over. The length is rounded DOWN to a multiple of chunk_length
# rather than hard-coding the slice point: 12 tokens with chunk_length 10
# gives batch_chunk_length 10, so the remainder is the last 2 tokens.
batch_chunk_length = (len(final_stream["input_ids"]) // chunk_length) * chunk_length
remainder = {k: v[batch_chunk_length:] for k, v in final_stream.items()}

print("\n4. New remainder calculated for next round:", remainder)


# ==========================================================
# --- ROUND 2 ---
# ==========================================================
print("\n\n" + "="*20 + " ROUND 2 " + "="*20)
batch_2 = {
    "input_ids": [
        [21, 22, 23, 24, 25],
        [26, 27, 28, 29, 30]
    ],
    "attention_mask": [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]
    ]
}
print("1. Input for Round 2:", batch_2)
print(f"   (Note: The current remainder from Round 1 is: {remainder})")


# First line again: flattening Batch 2
concatenated_examples = {k: list(chain(*rows)) for k, rows in batch_2.items()}
print("\n2. After flattening Batch 2:", concatenated_examples)

# Second line again: prepending the remainder FROM ROUND 1
final_stream = {
    k: remainder[k] + tokens for k, tokens in concatenated_examples.items()
}
print("\n3. Final stream after adding remainder:", final_stream)

--- INITIAL STATE ---
Initial Remainder: {'input_ids': [], 'attention_mask': []}

1. Input for Round 1: {'input_ids': [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12]], 'attention_mask': [[1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1, 1]]}

2. After flattening Batch 1: {'input_ids': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

3. Final stream after adding remainder: {'input_ids': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

4. New remainder calculated for next round: {'input_ids': [11, 12], 'attention_mask': [1, 1]}


1. Input for Round 2: {'input_ids': [[21, 22, 23, 24, 25], [26, 27, 28, 29, 30]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]}
   (Note: The current remainder from Round 1 is: {'input_ids': [11, 12], 'attention_mask': [1, 1]})

2. After flattening Batch 2: {'input_ids': [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

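3. Final stream after adding remainder: {'input_ids': [11, 12, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}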

In [None]:
# --- SETUP ---
# Let's use a small chunk_length for this example.
chunk_length = 5


# This is our "Before" state. A dictionary where each key holds one long,
# continuous list of tokens.
concatenated_examples = {
    "input_ids":      [10, 11, 12, 13, 14, 20, 21, 22, 23, 24, 30, 31, 32, 33, 34],
    "attention_mask": [1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1]
}

# Number of tokens that fill whole chunks: the stream length rounded DOWN to
# a multiple of chunk_length. A tail shorter than chunk_length is left out of
# `result` (it would be carried over as the remainder, as in the round-by-round
# demo) instead of producing a short final chunk.
# Here 15 // 5 * 5 = 15, so nothing is left out.
total_length = len(concatenated_examples["attention_mask"])
batch_chunk_length = (total_length // chunk_length) * chunk_length


print("--- Before Slicing ---")
print("This is our long, continuous stream of tokens:")
print(concatenated_examples)
print("\n" + "="*50 + "\n")


# --- THE OPERATION ---
# range(0, batch_chunk_length, chunk_length) yields the start index of each
# chunk (here i = 0, 5, 10); t[i : i + chunk_length] is that chunk. The same
# slicing is applied to every key so the chunks stay aligned across keys.
result = {
    k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
    for k, t in concatenated_examples.items()
}


# --- THE RESULT ---
print("--- After Slicing ---")
print("The stream has been chopped into a list of smaller lists (chunks):")
print(result)

--- Before Slicing ---
This is our long, continuous stream of tokens:
{'input_ids': [10, 11, 12, 13, 14, 20, 21, 22, 23, 24, 30, 31, 32, 33, 34], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


--- After Slicing ---
The stream has been chopped into a list of smaller lists (chunks):
{'input_ids': [[10, 11, 12, 13, 14], [20, 21, 22, 23, 24], [30, 31, 32, 33, 34]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]}


In [None]:
# Conceptual sketch: a one-argument wrapper around `chunk` whose chunk_length
# is fixed ("frozen") at 2048 — presumably what functools.partial(chunk,
# chunk_length=2048) would produce; confirm against the real call site.
# NOTE(review): `chunk` is not defined in any visible cell; calling
# new_function raises NameError unless `chunk` is defined elsewhere.
def new_function(sample):
    """Return ``chunk(sample, chunk_length=2048)``.

    Only ``sample`` varies between calls; the chunk length is fixed.
    """
    frozen_chunk_length = 2048
    return chunk(sample, chunk_length=frozen_chunk_length)

input_ids
[10, 11, 12, 13, 14, 20, 21, 22, 23, 24, 30, 31, 32, 33, 34]
attention_mask
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
